Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 123 additions & 192 deletions python/Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ crate-type = ["cdylib"]
datafusion-ffi = "50"
pyo3 = { version = "0.26", features = ["abi3-py39"] }
zarr-datafusion-search = { path = "../" }
zarrs_metadata = "0.6.1"

[patch.crates-io]
zarrs_metadata = { git = "https://github.com/zarrs/zarrs", rev = "6ea0464cf50e481cbae0e89de267c7991ed65f3f" }
2 changes: 1 addition & 1 deletion python/python/zarr_datafusion_search/_rust.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
class ZarrTable:
def __init__(self, path: str) -> None: ...
def __init__(self, zarr_path: str, group_path: str) -> None: ...
def __datafusion_table_provider__(self) -> object: ...
3 changes: 3 additions & 0 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#![cfg_attr(not(test), warn(unused_crate_dependencies))]

// Use patched version of zarrs-metadata
use zarrs_metadata as _;

mod table;

use pyo3::prelude::*;
Expand Down
16 changes: 9 additions & 7 deletions python/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;

use datafusion_ffi::table_provider::FFI_TableProvider;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyCapsule;
use zarr_datafusion_search::table_provider::ZarrTableProvider;

Expand All @@ -11,13 +12,14 @@ pub struct PyZarrTable(Arc<ZarrTableProvider>);
#[pymethods]
impl PyZarrTable {
#[new]
pub fn new(zarr_path: String) -> PyResult<Self> {
let table_provider = ZarrTableProvider::new_filesystem(zarr_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
"Failed to create ZarrTableProvider: {}",
e
))
})?;
pub fn new(zarr_path: String, group_path: PyBackedStr) -> PyResult<Self> {
let table_provider =
ZarrTableProvider::new_filesystem(zarr_path, &group_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
"Failed to create ZarrTableProvider: {}",
e
))
})?;
Ok(PyZarrTable(Arc::new(table_provider)))
}

Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_zarr_scan():
ctx = SessionContext()

zarr_path = ROOT_DIR / "data" / "zarr_store.zarr"
zarr_table = zarr_datafusion_search.ZarrTable(str(zarr_path))
zarr_table = zarr_datafusion_search.ZarrTable(str(zarr_path), "/meta")

ctx.register_table_provider("zarr_data", zarr_table)

Expand Down
19 changes: 12 additions & 7 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@ use thiserror::Error;

#[derive(Error, Debug)]
pub enum ZarrDataFusionError {
#[error("DataFusion error: {0}")]
DataFusion(#[from] DataFusionError),

#[error("Zarrs error: {0}")]
Zarrs(#[from] zarrs::array::ArrayError),
// Zarrs errors
#[error("Zarrs array creation error: {0}")]
ArrayCreateError(#[from] zarrs::array::ArrayCreateError),

#[error("Zarrs filesystem create error: {0}")]
FilesystemStoreCreateError(#[from] zarrs_filesystem::FilesystemStoreCreateError),

#[error("Zarrs array creation error: {0}")]
ArrayCreateError(#[from] zarrs::array::ArrayCreateError),
#[error("Zarrs group create error: {0}")]
GroupCreateError(#[from] zarrs::group::GroupCreateError),

#[error("Zarrs error: {0}")]
Zarrs(#[from] zarrs::array::ArrayError),

// Other errors
#[error("DataFusion error: {0}")]
DataFusion(#[from] DataFusionError),

#[error("Arrow error: {0}")]
Arrow(#[from] arrow::error::ArrowError),
Expand Down
135 changes: 81 additions & 54 deletions src/table_provider.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use arrow_array::{ArrayRef, RecordBatch, StringArray, TimestampMillisecondArray};
use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
use arrow_array::{ArrayRef, RecordBatch, StringViewArray, TimestampMillisecondArray};
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::datasource::{TableProvider, TableType};
Expand All @@ -14,19 +14,21 @@ use datafusion::physical_plan::{
SendableRecordBatchStream,
};
use geoarrow_array::GeoArrowArray;
use geoarrow_array::array::WktArray;
use geoarrow_schema::{Crs, WktType};
use geoarrow_array::array::WktViewArray;
use geoarrow_schema::Crs;
use object_store::ObjectStore;
use std::any::Any;
use std::fmt::{self, Debug};
use std::sync::Arc;
use zarrs::array::{Array, ElementOwned};
use zarrs::array_subset::ArraySubset;
use zarrs::group::Group;
use zarrs::storage::{AsyncReadableListableStorageTraits, ReadableListableStorageTraits};
use zarrs_filesystem::{FilesystemStore, FilesystemStoreCreateError};
use zarrs_storage::{MaybeSend, MaybeSync};

use crate::error::ZarrDataFusionResult;
use crate::schema::{group_arrays_schema, group_arrays_schema_async};

/// A simple DataFusion table provider that loads data from a Zarr store
#[derive(Debug)]
Expand All @@ -38,40 +40,27 @@ pub struct ZarrTableProvider {
impl ZarrTableProvider {
/// Create a new ZarrTableProvider backed by a filesystem Zarr store at `base_path`, inferring the table schema from the group at `group_path`
pub fn new_filesystem<P: AsRef<std::path::Path>>(
zarr_path: P,
) -> Result<Self, Box<dyn std::error::Error>> {
let zarr_backend = ZarrBackend::new_filesystem(zarr_path)?;
let schema = Self::construct_schema();
base_path: P,
group_path: &str,
) -> ZarrDataFusionResult<Self> {
let zarr_backend = SyncZarrBackend::new_filesystem(base_path)?;
let schema = zarr_backend.infer_group_schema(group_path)?;
Ok(Self {
schema,
zarr_backend,
zarr_backend: zarr_backend.into(),
})
}

pub fn new_object_store<T: ObjectStore>(store: T) -> Self {
let zarr_backend = ZarrBackend::new_object_store(store);
let schema = Self::construct_schema();
Self {
pub async fn new_object_store<T: ObjectStore>(
store: T,
group_path: &str,
) -> ZarrDataFusionResult<Self> {
let zarr_backend = AsyncZarrBackend::new_object_store(store);
let schema = zarr_backend.infer_group_schema(group_path).await?;
Ok(Self {
schema,
zarr_backend,
}
}

fn construct_schema() -> SchemaRef {
// Define the schema based on the expected Zarr arrays
let wkt_crs = Crs::from_authority_code("EPSG:4326".to_string());
let wkt_metadata = Arc::new(geoarrow_schema::Metadata::new(wkt_crs, None));

Arc::new(Schema::new(vec![
Field::new("collection", DataType::Utf8, false),
Field::new(
"date",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new("wkt_field", DataType::Utf8, false)
.with_extension_type(WktType::new(wkt_metadata)),
]))
zarr_backend: zarr_backend.into(),
})
}
}

Expand Down Expand Up @@ -108,17 +97,32 @@ impl TableProvider for ZarrTableProvider {
struct SyncZarrBackend(Arc<dyn ReadableListableStorageTraits>);

impl SyncZarrBackend {
    /// Builds a backend over a `FilesystemStore` rooted at `base_path`.
    ///
    /// # Errors
    /// Returns the `FilesystemStoreCreateError` reported by `FilesystemStore::new`.
    fn new_filesystem<P: AsRef<std::path::Path>>(
        base_path: P,
    ) -> Result<Self, FilesystemStoreCreateError> {
        let store = FilesystemStore::new(base_path)?;
        Ok(Self(Arc::new(store)))
    }

    /// Opens the array at `path` and reads all of its elements, using a subset
    /// built from the array's full shape.
    fn load_array<T: ElementOwned>(&self, path: &str) -> ZarrDataFusionResult<Vec<T>> {
        let storage = self.0.clone();
        let array = Array::open(storage, path)?;
        let whole_array = ArraySubset::new_with_shape(array.shape().to_vec());
        let elements = array.retrieve_array_subset_elements(&whole_array)?;
        Ok(elements)
    }

    /// Opens the group at `group_path` and derives a schema from it via
    /// `group_arrays_schema`.
    fn infer_group_schema(&self, group_path: &str) -> ZarrDataFusionResult<SchemaRef> {
        let storage = self.0.clone();
        let group = Group::open(storage, group_path)?;
        group_arrays_schema(&group)
    }
}

#[derive(Clone)]
struct AsyncZarrBackend(Arc<dyn AsyncReadableListableStorageTraits>);

impl AsyncZarrBackend {
fn new_object_store<T: ObjectStore>(store: T) -> Self {
AsyncZarrBackend(Arc::new(zarrs_object_store::AsyncObjectStore::new(store)))
}

async fn load_array<T: ElementOwned + MaybeSend + MaybeSync>(
&self,
path: &str,
Expand All @@ -129,37 +133,52 @@ impl AsyncZarrBackend {
.async_retrieve_array_subset_elements(&full_subset)
.await?)
}

/// Asynchronously opens the group at `group_path` and derives a schema from
/// it via `group_arrays_schema_async`.
async fn infer_group_schema(&self, group_path: &str) -> ZarrDataFusionResult<SchemaRef> {
    let storage = self.0.clone();
    let group = Group::async_open(storage, group_path).await?;
    group_arrays_schema_async(&group).await
}
}

#[derive(Clone)]
enum ZarrBackend {
Sync(SyncZarrBackend),
Async(AsyncZarrBackend),
Sync(SyncZarrBackend),
}

impl Debug for ZarrBackend {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ZarrBackend::Sync(_) => write!(f, "ZarrBackend::Sync"),
ZarrBackend::Async(_) => write!(f, "ZarrBackend::Async"),
ZarrBackend::Sync(_) => write!(f, "ZarrBackend::Sync"),
}
}
}

impl ZarrBackend {
fn new_filesystem<P: AsRef<std::path::Path>>(
base_path: P,
) -> Result<Self, FilesystemStoreCreateError> {
Ok(Self::Sync(SyncZarrBackend(Arc::new(FilesystemStore::new(
base_path,
)?))))
impl From<AsyncZarrBackend> for ZarrBackend {
fn from(async_backend: AsyncZarrBackend) -> Self {
ZarrBackend::Async(async_backend)
}
}

fn new_object_store<T: ObjectStore>(store: T) -> Self {
Self::Async(AsyncZarrBackend(Arc::new(
zarrs_object_store::AsyncObjectStore::new(store),
)))
impl From<SyncZarrBackend> for ZarrBackend {
fn from(sync_backend: SyncZarrBackend) -> Self {
ZarrBackend::Sync(sync_backend)
}
}

impl ZarrBackend {
// fn new_filesystem<P: AsRef<std::path::Path>>(
// base_path: P,
// ) -> Result<Self, FilesystemStoreCreateError> {
// Ok(Self::Sync(SyncZarrBackend::new_filesystem(base_path)?))
// }

// fn new_object_store<T: ObjectStore>(store: T) -> Self {
// Self::Async(AsyncZarrBackend(Arc::new(
// zarrs_object_store::AsyncObjectStore::new(store),
// )))
// }

async fn load_array<T: ElementOwned + MaybeSend + MaybeSync>(
&self,
Expand All @@ -177,17 +196,25 @@ impl ZarrBackend {
let bbox_data: Vec<String> = self.load_array("/meta/bbox").await?;

// Create Arrow arrays from the loaded data
let collection_arrow: ArrayRef = Arc::new(StringArray::from(collection_data));
let collection_arrow: ArrayRef = Arc::new(StringViewArray::from(collection_data));
let date_arrow: ArrayRef = Arc::new(TimestampMillisecondArray::from(date_data));
let wkt_crs = Crs::from_authority_code("EPSG:4326".to_string());
let wkt_metadata = Arc::new(geoarrow_schema::Metadata::new(wkt_crs, None));
let wkt_arrow = WktArray::new(bbox_data.into(), wkt_metadata);
let wkt_arrow = WktViewArray::new(bbox_data.into(), wkt_metadata);

let columns = schema
.fields()
.iter()
.map(|field| match field.name().as_str() {
"collection" => collection_arrow.clone(),
"date" => date_arrow.clone(),
"bbox" => wkt_arrow.clone().into_array_ref(),
_ => panic!("Unexpected field name: {}", field.name()),
})
.collect();

// Create the RecordBatch
let record_batch = RecordBatch::try_new(
schema,
vec![collection_arrow, date_arrow, wkt_arrow.into_array_ref()],
)?;
let record_batch = RecordBatch::try_new(schema, columns)?;

Ok(record_batch)
}
Expand Down Expand Up @@ -279,7 +306,7 @@ mod tests {

#[tokio::test]
async fn test_basic_table_provider() {
let provider = ZarrTableProvider::new_filesystem("data/zarr_store.zarr").unwrap();
let provider = ZarrTableProvider::new_filesystem("data/zarr_store.zarr", "/meta").unwrap();

// Register with DataFusion
let ctx = SessionContext::new();
Expand All @@ -300,7 +327,7 @@ mod tests {
#[tokio::test]
#[ignore = "Projection support"]
async fn test_table_provider_with_sql() {
let provider = ZarrTableProvider::new_filesystem("data/zarr_store.zarr").unwrap();
let provider = ZarrTableProvider::new_filesystem("data/zarr_store.zarr", "/meta").unwrap();

// Register with DataFusion
let ctx = SessionContext::new();
Expand All @@ -325,7 +352,7 @@ mod tests {
let collection_col = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.downcast_ref::<StringViewArray>()
.unwrap();
assert_eq!(collection_col.value(0), "collection_a");
}
Expand Down