Skip to content

Commit dd5b943

Browse files
authored
feat: Use schema inference code path in Zarr table provider (#21)
1 parent 972cd6f commit dd5b943

8 files changed

Lines changed: 234 additions & 262 deletions

File tree

python/Cargo.lock

Lines changed: 123 additions & 192 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,3 +21,7 @@ crate-type = ["cdylib"]
2121
datafusion-ffi = "50"
2222
pyo3 = { version = "0.26", features = ["abi3-py39"] }
2323
zarr-datafusion-search = { path = "../" }
24+
zarrs_metadata = "0.6.1"
25+
26+
[patch.crates-io]
27+
zarrs_metadata = { git = "https://github.com/zarrs/zarrs", rev = "6ea0464cf50e481cbae0e89de267c7991ed65f3f" }
Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
11
class ZarrTable:
2-
def __init__(self, path: str) -> None: ...
2+
def __init__(self, zarr_path: str, group_path: str) -> None: ...
33
def __datafusion_table_provider__(self) -> object: ...

python/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
11
#![cfg_attr(not(test), warn(unused_crate_dependencies))]
22

3+
// Use patched version of zarrs-metadata
4+
use zarrs_metadata as _;
5+
36
mod table;
47

58
use pyo3::prelude::*;

python/src/table.rs

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ use std::sync::Arc;
22

33
use datafusion_ffi::table_provider::FFI_TableProvider;
44
use pyo3::prelude::*;
5+
use pyo3::pybacked::PyBackedStr;
56
use pyo3::types::PyCapsule;
67
use zarr_datafusion_search::table_provider::ZarrTableProvider;
78

@@ -11,13 +12,14 @@ pub struct PyZarrTable(Arc<ZarrTableProvider>);
1112
#[pymethods]
1213
impl PyZarrTable {
1314
#[new]
14-
pub fn new(zarr_path: String) -> PyResult<Self> {
15-
let table_provider = ZarrTableProvider::new_filesystem(zarr_path).map_err(|e| {
16-
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
17-
"Failed to create ZarrTableProvider: {}",
18-
e
19-
))
20-
})?;
15+
pub fn new(zarr_path: String, group_path: PyBackedStr) -> PyResult<Self> {
16+
let table_provider =
17+
ZarrTableProvider::new_filesystem(zarr_path, &group_path).map_err(|e| {
18+
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
19+
"Failed to create ZarrTableProvider: {}",
20+
e
21+
))
22+
})?;
2123
Ok(PyZarrTable(Arc::new(table_provider)))
2224
}
2325

python/tests/test_datafusion.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ def test_zarr_scan():
1010
ctx = SessionContext()
1111

1212
zarr_path = ROOT_DIR / "data" / "zarr_store.zarr"
13-
zarr_table = zarr_datafusion_search.ZarrTable(str(zarr_path))
13+
zarr_table = zarr_datafusion_search.ZarrTable(str(zarr_path), "/meta")
1414

1515
ctx.register_table_provider("zarr_data", zarr_table)
1616

src/error.rs

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,17 +3,22 @@ use thiserror::Error;
33

44
#[derive(Error, Debug)]
55
pub enum ZarrDataFusionError {
6-
#[error("DataFusion error: {0}")]
7-
DataFusion(#[from] DataFusionError),
8-
9-
#[error("Zarrs error: {0}")]
10-
Zarrs(#[from] zarrs::array::ArrayError),
6+
// Zarrs errors
7+
#[error("Zarrs array creation error: {0}")]
8+
ArrayCreateError(#[from] zarrs::array::ArrayCreateError),
119

1210
#[error("Zarrs filesystem create error: {0}")]
1311
FilesystemStoreCreateError(#[from] zarrs_filesystem::FilesystemStoreCreateError),
1412

15-
#[error("Zarrs array creation error: {0}")]
16-
ArrayCreateError(#[from] zarrs::array::ArrayCreateError),
13+
#[error("Zarrs group create error: {0}")]
14+
GroupCreateError(#[from] zarrs::group::GroupCreateError),
15+
16+
#[error("Zarrs error: {0}")]
17+
Zarrs(#[from] zarrs::array::ArrayError),
18+
19+
// Other errors
20+
#[error("DataFusion error: {0}")]
21+
DataFusion(#[from] DataFusionError),
1722

1823
#[error("Arrow error: {0}")]
1924
Arrow(#[from] arrow::error::ArrowError),

src/table_provider.rs

Lines changed: 81 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
1-
use arrow_array::{ArrayRef, RecordBatch, StringArray, TimestampMillisecondArray};
2-
use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
1+
use arrow_array::{ArrayRef, RecordBatch, StringViewArray, TimestampMillisecondArray};
2+
use arrow_schema::SchemaRef;
33
use async_trait::async_trait;
44
use datafusion::catalog::Session;
55
use datafusion::datasource::{TableProvider, TableType};
@@ -14,19 +14,21 @@ use datafusion::physical_plan::{
1414
SendableRecordBatchStream,
1515
};
1616
use geoarrow_array::GeoArrowArray;
17-
use geoarrow_array::array::WktArray;
18-
use geoarrow_schema::{Crs, WktType};
17+
use geoarrow_array::array::WktViewArray;
18+
use geoarrow_schema::Crs;
1919
use object_store::ObjectStore;
2020
use std::any::Any;
2121
use std::fmt::{self, Debug};
2222
use std::sync::Arc;
2323
use zarrs::array::{Array, ElementOwned};
2424
use zarrs::array_subset::ArraySubset;
25+
use zarrs::group::Group;
2526
use zarrs::storage::{AsyncReadableListableStorageTraits, ReadableListableStorageTraits};
2627
use zarrs_filesystem::{FilesystemStore, FilesystemStoreCreateError};
2728
use zarrs_storage::{MaybeSend, MaybeSync};
2829

2930
use crate::error::ZarrDataFusionResult;
31+
use crate::schema::{group_arrays_schema, group_arrays_schema_async};
3032

3133
/// A simple DataFusion table provider that loads data from a Zarr store
3234
#[derive(Debug)]
@@ -38,40 +40,27 @@ pub struct ZarrTableProvider {
3840
impl ZarrTableProvider {
3941
/// Create a new ZarrTableProvider from a Zarr store path
4042
pub fn new_filesystem<P: AsRef<std::path::Path>>(
41-
zarr_path: P,
42-
) -> Result<Self, Box<dyn std::error::Error>> {
43-
let zarr_backend = ZarrBackend::new_filesystem(zarr_path)?;
44-
let schema = Self::construct_schema();
43+
base_path: P,
44+
group_path: &str,
45+
) -> ZarrDataFusionResult<Self> {
46+
let zarr_backend = SyncZarrBackend::new_filesystem(base_path)?;
47+
let schema = zarr_backend.infer_group_schema(group_path)?;
4548
Ok(Self {
4649
schema,
47-
zarr_backend,
50+
zarr_backend: zarr_backend.into(),
4851
})
4952
}
5053

51-
pub fn new_object_store<T: ObjectStore>(store: T) -> Self {
52-
let zarr_backend = ZarrBackend::new_object_store(store);
53-
let schema = Self::construct_schema();
54-
Self {
54+
pub async fn new_object_store<T: ObjectStore>(
55+
store: T,
56+
group_path: &str,
57+
) -> ZarrDataFusionResult<Self> {
58+
let zarr_backend = AsyncZarrBackend::new_object_store(store);
59+
let schema = zarr_backend.infer_group_schema(group_path).await?;
60+
Ok(Self {
5561
schema,
56-
zarr_backend,
57-
}
58-
}
59-
60-
fn construct_schema() -> SchemaRef {
61-
// Define the schema based on the expected Zarr arrays
62-
let wkt_crs = Crs::from_authority_code("EPSG:4326".to_string());
63-
let wkt_metadata = Arc::new(geoarrow_schema::Metadata::new(wkt_crs, None));
64-
65-
Arc::new(Schema::new(vec![
66-
Field::new("collection", DataType::Utf8, false),
67-
Field::new(
68-
"date",
69-
DataType::Timestamp(TimeUnit::Millisecond, None),
70-
false,
71-
),
72-
Field::new("wkt_field", DataType::Utf8, false)
73-
.with_extension_type(WktType::new(wkt_metadata)),
74-
]))
62+
zarr_backend: zarr_backend.into(),
63+
})
7564
}
7665
}
7766

@@ -108,17 +97,32 @@ impl TableProvider for ZarrTableProvider {
10897
struct SyncZarrBackend(Arc<dyn ReadableListableStorageTraits>);
10998

11099
impl SyncZarrBackend {
100+
fn new_filesystem<P: AsRef<std::path::Path>>(
101+
base_path: P,
102+
) -> Result<Self, FilesystemStoreCreateError> {
103+
Ok(SyncZarrBackend(Arc::new(FilesystemStore::new(base_path)?)))
104+
}
105+
111106
fn load_array<T: ElementOwned>(&self, path: &str) -> ZarrDataFusionResult<Vec<T>> {
112107
let array = Array::open(self.0.clone(), path)?;
113108
let full_subset = ArraySubset::new_with_shape(array.shape().to_vec());
114109
Ok(array.retrieve_array_subset_elements(&full_subset)?)
115110
}
111+
112+
fn infer_group_schema(&self, group_path: &str) -> ZarrDataFusionResult<SchemaRef> {
113+
let group = Group::open(self.0.clone(), group_path)?;
114+
group_arrays_schema(&group)
115+
}
116116
}
117117

118118
#[derive(Clone)]
119119
struct AsyncZarrBackend(Arc<dyn AsyncReadableListableStorageTraits>);
120120

121121
impl AsyncZarrBackend {
122+
fn new_object_store<T: ObjectStore>(store: T) -> Self {
123+
AsyncZarrBackend(Arc::new(zarrs_object_store::AsyncObjectStore::new(store)))
124+
}
125+
122126
async fn load_array<T: ElementOwned + MaybeSend + MaybeSync>(
123127
&self,
124128
path: &str,
@@ -129,37 +133,52 @@ impl AsyncZarrBackend {
129133
.async_retrieve_array_subset_elements(&full_subset)
130134
.await?)
131135
}
136+
137+
async fn infer_group_schema(&self, group_path: &str) -> ZarrDataFusionResult<SchemaRef> {
138+
let group = Group::async_open(self.0.clone(), group_path).await?;
139+
group_arrays_schema_async(&group).await
140+
}
132141
}
133142

134143
#[derive(Clone)]
135144
enum ZarrBackend {
136-
Sync(SyncZarrBackend),
137145
Async(AsyncZarrBackend),
146+
Sync(SyncZarrBackend),
138147
}
139148

140149
impl Debug for ZarrBackend {
141150
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142151
match self {
143-
ZarrBackend::Sync(_) => write!(f, "ZarrBackend::Sync"),
144152
ZarrBackend::Async(_) => write!(f, "ZarrBackend::Async"),
153+
ZarrBackend::Sync(_) => write!(f, "ZarrBackend::Sync"),
145154
}
146155
}
147156
}
148157

149-
impl ZarrBackend {
150-
fn new_filesystem<P: AsRef<std::path::Path>>(
151-
base_path: P,
152-
) -> Result<Self, FilesystemStoreCreateError> {
153-
Ok(Self::Sync(SyncZarrBackend(Arc::new(FilesystemStore::new(
154-
base_path,
155-
)?))))
158+
impl From<AsyncZarrBackend> for ZarrBackend {
159+
fn from(async_backend: AsyncZarrBackend) -> Self {
160+
ZarrBackend::Async(async_backend)
156161
}
162+
}
157163

158-
fn new_object_store<T: ObjectStore>(store: T) -> Self {
159-
Self::Async(AsyncZarrBackend(Arc::new(
160-
zarrs_object_store::AsyncObjectStore::new(store),
161-
)))
164+
impl From<SyncZarrBackend> for ZarrBackend {
165+
fn from(sync_backend: SyncZarrBackend) -> Self {
166+
ZarrBackend::Sync(sync_backend)
162167
}
168+
}
169+
170+
impl ZarrBackend {
171+
// fn new_filesystem<P: AsRef<std::path::Path>>(
172+
// base_path: P,
173+
// ) -> Result<Self, FilesystemStoreCreateError> {
174+
// Ok(Self::Sync(SyncZarrBackend::new_filesystem(base_path)?))
175+
// }
176+
177+
// fn new_object_store<T: ObjectStore>(store: T) -> Self {
178+
// Self::Async(AsyncZarrBackend(Arc::new(
179+
// zarrs_object_store::AsyncObjectStore::new(store),
180+
// )))
181+
// }
163182

164183
async fn load_array<T: ElementOwned + MaybeSend + MaybeSync>(
165184
&self,
@@ -177,17 +196,25 @@ impl ZarrBackend {
177196
let bbox_data: Vec<String> = self.load_array("/meta/bbox").await?;
178197

179198
// Create Arrow arrays from the loaded data
180-
let collection_arrow: ArrayRef = Arc::new(StringArray::from(collection_data));
199+
let collection_arrow: ArrayRef = Arc::new(StringViewArray::from(collection_data));
181200
let date_arrow: ArrayRef = Arc::new(TimestampMillisecondArray::from(date_data));
182201
let wkt_crs = Crs::from_authority_code("EPSG:4326".to_string());
183202
let wkt_metadata = Arc::new(geoarrow_schema::Metadata::new(wkt_crs, None));
184-
let wkt_arrow = WktArray::new(bbox_data.into(), wkt_metadata);
203+
let wkt_arrow = WktViewArray::new(bbox_data.into(), wkt_metadata);
204+
205+
let columns = schema
206+
.fields()
207+
.iter()
208+
.map(|field| match field.name().as_str() {
209+
"collection" => collection_arrow.clone(),
210+
"date" => date_arrow.clone(),
211+
"bbox" => wkt_arrow.clone().into_array_ref(),
212+
_ => panic!("Unexpected field name: {}", field.name()),
213+
})
214+
.collect();
185215

186216
// Create the RecordBatch
187-
let record_batch = RecordBatch::try_new(
188-
schema,
189-
vec![collection_arrow, date_arrow, wkt_arrow.into_array_ref()],
190-
)?;
217+
let record_batch = RecordBatch::try_new(schema, columns)?;
191218

192219
Ok(record_batch)
193220
}
@@ -279,7 +306,7 @@ mod tests {
279306

280307
#[tokio::test]
281308
async fn test_basic_table_provider() {
282-
let provider = ZarrTableProvider::new_filesystem("data/zarr_store.zarr").unwrap();
309+
let provider = ZarrTableProvider::new_filesystem("data/zarr_store.zarr", "/meta").unwrap();
283310

284311
// Register with DataFusion
285312
let ctx = SessionContext::new();
@@ -300,7 +327,7 @@ mod tests {
300327
#[tokio::test]
301328
#[ignore = "Projection support"]
302329
async fn test_table_provider_with_sql() {
303-
let provider = ZarrTableProvider::new_filesystem("data/zarr_store.zarr").unwrap();
330+
let provider = ZarrTableProvider::new_filesystem("data/zarr_store.zarr", "/meta").unwrap();
304331

305332
// Register with DataFusion
306333
let ctx = SessionContext::new();
@@ -325,7 +352,7 @@ mod tests {
325352
let collection_col = batch
326353
.column(0)
327354
.as_any()
328-
.downcast_ref::<StringArray>()
355+
.downcast_ref::<StringViewArray>()
329356
.unwrap();
330357
assert_eq!(collection_col.value(0), "collection_a");
331358
}

0 commit comments

Comments
 (0)