6 changes: 5 additions & 1 deletion datafusion/common/src/dfschema.rs
@@ -913,7 +913,11 @@ impl TryFrom<SchemaRef> for DFSchema {
field_qualifiers: vec![None; field_count],
functional_dependencies: FunctionalDependencies::empty(),
};
dfschema.check_names()?;
// Do not check names here: the schema may contain duplicate field names.
// For example, AggregateMode::Partial generates duplicate field names from
// its state_fields.
// See <https://github.com/apache/datafusion/issues/17715>
// dfschema.check_names()?;
Ok(dfschema)
}
}
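For context, a minimal sketch of what this change unblocks: building a `DFSchema` from an Arrow schema that contains duplicate field names, as a partial aggregate's state fields can produce. The field names below mirror the test later in this PR but are otherwise illustrative.

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion_common::DFSchema;

fn main() -> datafusion_common::Result<()> {
    // Two state fields share the name "is_set", as AggregateMode::Partial
    // can emit for first_value/last_value.
    let schema: SchemaRef = Arc::new(Schema::new(vec![
        Field::new("first_value(value)[first_value]", DataType::Float64, true),
        Field::new("is_set", DataType::Boolean, true),
        Field::new("last_value(value)[last_value]", DataType::Float64, true),
        Field::new("is_set", DataType::Boolean, true),
    ]));

    // Before this change, check_names() rejected the duplicate names;
    // with the check skipped, the conversion succeeds.
    let df_schema = DFSchema::try_from(schema)?;
    assert_eq!(df_schema.fields().len(), 4);
    Ok(())
}
```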
8 changes: 8 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
@@ -1727,6 +1727,14 @@ impl FunctionRegistry for SessionContext {
) -> Result<()> {
self.state.write().register_expr_planner(expr_planner)
}

fn udafs(&self) -> HashSet<String> {
self.state.read().udafs()
}

fn udwfs(&self) -> HashSet<String> {
self.state.read().udwfs()
}
}
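A hedged usage sketch for the new accessors, assuming a default context; exactly which names are present depends on the default function registry and the enabled features.

```rust
use std::collections::HashSet;

use datafusion::execution::FunctionRegistry;
use datafusion::prelude::SessionContext;

fn main() {
    let ctx = SessionContext::new();

    // Names of every registered UDAF and UDWF.
    let udafs: HashSet<String> = ctx.udafs();
    let udwfs: HashSet<String> = ctx.udwfs();

    // "sum" ships in the default aggregate registry (assumed).
    assert!(udafs.contains("sum"));
    assert!(!udwfs.is_empty());
}
```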

/// Create a new task context instance from SessionContext
8 changes: 8 additions & 0 deletions datafusion/core/src/execution/session_state.rs
@@ -1881,6 +1881,14 @@ impl FunctionRegistry for SessionState {
self.expr_planners.push(expr_planner);
Ok(())
}

fn udafs(&self) -> HashSet<String> {
self.aggregate_functions.keys().cloned().collect()
}

fn udwfs(&self) -> HashSet<String> {
self.window_functions.keys().cloned().collect()
}
}

impl OptimizerConfig for SessionState {
115 changes: 112 additions & 3 deletions datafusion/core/tests/dataframe/mod.rs
@@ -32,6 +32,7 @@ use arrow::datatypes::{
};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::{SortOptions, TimeUnit};
use datafusion::{assert_batches_eq, dataframe};
use datafusion_functions_aggregate::count::{count_all, count_all_window};
use datafusion_functions_aggregate::expr_fn::{
@@ -64,8 +65,8 @@ use datafusion::test_util::{
use datafusion_catalog::TableProvider;
use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
use datafusion_common::{
assert_contains, Constraint, Constraints, DataFusionError, ParamValues, ScalarValue,
TableReference, UnnestOptions,
assert_contains, Constraint, Constraints, DFSchema, DataFusionError, ParamValues,
ScalarValue, TableReference, UnnestOptions,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_datasource::file_format::format_as_file_type;
@@ -79,10 +80,16 @@ use datafusion_expr::{
LogicalPlanBuilder, ScalarFunctionImplementation, SortExpr, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::Partitioning;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_plan::{displayable, ExecutionPlanProperties};
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use datafusion_physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
use datafusion_physical_plan::empty::EmptyExec;
use datafusion_physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties};

// Get string representation of the plan
async fn physical_plan_to_string(df: &DataFrame) -> String {
@@ -6322,3 +6329,105 @@ async fn test_copy_to_preserves_order() -> Result<()> {
);
Ok(())
}

#[tokio::test]
async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> {
let ctx = SessionContext::new();

// Simple schema with just the fields we need
let file_schema = Arc::new(Schema::new(vec![
Field::new(
"timestamp",
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
true,
),
Field::new("ticker", DataType::Utf8, true),
Field::new("value", DataType::Float64, true),
Field::new("date", DataType::Utf8, false),
]));

let df_schema = DFSchema::try_from(file_schema.clone())?;

let timestamp = col("timestamp");
let value = col("value");
let ticker = col("ticker");
let date = col("date");

let mock_exec = Arc::new(EmptyExec::new(file_schema.clone()));

// Build first_value aggregate
let first_value = Arc::new(
AggregateExprBuilder::new(
datafusion_functions_aggregate::first_last::first_value_udaf(),
vec![ctx.create_physical_expr(value.clone(), &df_schema)?],
)
.alias("first_value(value)")
.order_by(vec![PhysicalSortExpr::new(
ctx.create_physical_expr(timestamp.clone(), &df_schema)?,
SortOptions::new(false, false),
)])
.schema(file_schema.clone())
.build()
.expect("Failed to build first_value"),
);

// Build last_value aggregate
let last_value = Arc::new(
AggregateExprBuilder::new(
datafusion_functions_aggregate::first_last::last_value_udaf(),
vec![ctx.create_physical_expr(value.clone(), &df_schema)?],
)
.alias("last_value(value)")
.order_by(vec![PhysicalSortExpr::new(
ctx.create_physical_expr(timestamp.clone(), &df_schema)?,
SortOptions::new(false, false),
)])
.schema(file_schema.clone())
.build()
.expect("Failed to build last_value"),
);

let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::new_single(vec![
(
ctx.create_physical_expr(date.clone(), &df_schema)?,
"date".to_string(),
),
(
ctx.create_physical_expr(ticker.clone(), &df_schema)?,
"ticker".to_string(),
),
]),
vec![first_value, last_value],
vec![None, None],
mock_exec,
file_schema,
)
.expect("Failed to build partial agg");

// Assert that the schema field names match the expected names
let expected_field_names = vec![
"date",
"ticker",
"first_value(value)[first_value]",
"timestamp@0",
"is_set",
"last_value(value)[last_value]",
"timestamp@0",
"is_set",
];

let binding = partial_agg.schema();
let actual_field_names: Vec<_> = binding.fields().iter().map(|f| f.name()).collect();
assert_eq!(actual_field_names, expected_field_names);

// Ensure that DFSchema::try_from does not fail
let partial_agg_exec_schema = DFSchema::try_from(partial_agg.schema());
assert!(
partial_agg_exec_schema.is_ok(),
"Expected get AggregateExec schema to succeed with duplicate state fields"
);

Ok(())
}
5 changes: 3 additions & 2 deletions datafusion/datasource-parquet/src/reader.rs
@@ -209,6 +209,7 @@ impl ParquetFileReaderFactory for CachedParquetFileReaderFactory {
file_metrics,
file_meta,
metadata_cache: Arc::clone(&self.metadata_cache),
metadata_size_hint,
}))
}
}
@@ -222,6 +223,7 @@ pub struct CachedParquetFileReader {
pub inner: ParquetObjectReader,
file_meta: FileMeta,
metadata_cache: Arc<dyn FileMetadataCache>,
metadata_size_hint: Option<usize>,
}

impl AsyncFileReader for CachedParquetFileReader {
@@ -261,11 +263,10 @@ impl AsyncFileReader for CachedParquetFileReader {
#[cfg(not(feature = "parquet_encryption"))]
let file_decryption_properties = None;

// TODO there should be metadata prefetch hint here
// https://github.com/apache/datafusion/issues/17279
DFParquetMetadata::new(&self.store, &file_meta.object_meta)
.with_decryption_properties(file_decryption_properties)
.with_file_metadata_cache(Some(Arc::clone(&metadata_cache)))
.with_metadata_size_hint(self.metadata_size_hint)
.fetch_metadata()
.await
.map_err(|e| {
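With the hint now threaded through to `DFParquetMetadata`, an uncached footer read can usually be satisfied by a single suffix request. A sketch of setting the hint at the session level; the config key and the 512 KiB value are assumptions based on DataFusion's documented options, not part of this diff.

```rust
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // Prefetch the trailing 512 KiB of each Parquet file so the footer and
    // metadata are typically fetched in one request (value assumed).
    let config = SessionConfig::new()
        .set_str("datafusion.execution.parquet.metadata_size_hint", "524288");
    let ctx = SessionContext::new_with_config(config);

    let df = ctx
        .read_parquet("data.parquet", ParquetReadOptions::default())
        .await?;
    df.show().await?;
    Ok(())
}
```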
73 changes: 71 additions & 2 deletions datafusion/datasource/src/file_scan_config.rs
@@ -590,7 +590,11 @@ impl DataSource for FileScanConfig {
// Note that this will *ignore* any non-projected columns: these don't factor into ordering / equivalence.
match reassign_predicate_columns(filter, &schema, true) {
Ok(filter) => {
match Self::add_filter_equivalence_info(filter, &mut eq_properties) {
match Self::add_filter_equivalence_info(
filter,
&mut eq_properties,
&schema,
) {
Ok(()) => {}
Err(e) => {
warn!("Failed to add filter equivalence info: {e}");
@@ -758,9 +762,24 @@ impl FileScanConfig {
fn add_filter_equivalence_info(
filter: Arc<dyn PhysicalExpr>,
eq_properties: &mut EquivalenceProperties,
schema: &Schema,
) -> Result<()> {
macro_rules! ignore_dangling_col {
($col:expr) => {
if let Some(col) = $col.as_any().downcast_ref::<Column>() {
if schema.index_of(col.name()).is_err() {
continue;
}
}
};
}

let (equal_pairs, _) = collect_columns_from_predicate(&filter);
for (lhs, rhs) in equal_pairs {
// Ignore any binary expressions that reference non-existent columns in the current schema
// (e.g. due to unnecessary projections being removed)
ignore_dangling_col!(lhs);
ignore_dangling_col!(rhs);
eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))?
}
Ok(())
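To make the guard concrete, a minimal standalone sketch (schema and column names assumed) of the check the macro performs: a column left behind by a removed projection fails `Schema::index_of`, so its equality pair is skipped.

```rust
use arrow_schema::{DataType, Field, Schema};
use datafusion_physical_expr::expressions::Column;

fn main() {
    // Projected schema after `c2` was dropped (illustrative).
    let schema = Schema::new(vec![
        Field::new("c1", DataType::Utf8, false),
        Field::new("c3", DataType::Int32, false),
    ]);

    // A filter pushed down before the projection changed may still
    // reference `c2`; index_of fails, so the pair is ignored.
    let dangling = Column::new("c2", 1);
    assert!(schema.index_of(dangling.name()).is_err());

    // Columns still present resolve normally.
    let live = Column::new("c3", 1);
    assert!(schema.index_of(live.name()).is_ok());
}
```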
@@ -1449,6 +1468,7 @@ pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue {
#[cfg(test)]
mod tests {
use super::*;
use crate::test_util::col;
use crate::{
generate_test_files, test_util::MockSource, tests::aggr_test_schema,
verify_sort_integrity,
@@ -1457,8 +1477,9 @@ mod tests {
use arrow::array::{Int32Array, RecordBatch};
use datafusion_common::stats::Precision;
use datafusion_common::{assert_batches_eq, internal_err};
use datafusion_expr::SortExpr;
use datafusion_expr::{Operator, SortExpr};
use datafusion_physical_expr::create_physical_sort_expr;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};

/// Returns the column names on the schema
pub fn columns(schema: &Schema) -> Vec<String> {
@@ -2214,6 +2235,54 @@ mod tests {
assert_eq!(config.output_ordering.len(), 1);
}

#[test]
fn equivalence_properties_after_schema_change() {
let file_schema = aggr_test_schema();
let object_store_url = ObjectStoreUrl::parse("test:///").unwrap();
// Create a file source with a filter
let file_source: Arc<dyn FileSource> =
Arc::new(MockSource::default().with_filter(Arc::new(BinaryExpr::new(
col("c2", &file_schema).unwrap(),
Operator::Eq,
Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
))));

let config = FileScanConfigBuilder::new(
object_store_url.clone(),
Arc::clone(&file_schema),
Arc::clone(&file_source),
)
.with_projection(Some(vec![0, 1, 2]))
.build();

// Simulate the projection being updated. Since the filter has already been pushed down,
// the new projection won't include the filtered column.
let data_source = config
.try_swapping_with_projection(&[ProjectionExpr::new(
col("c3", &file_schema).unwrap(),
"c3".to_string(),
)])
.unwrap()
.unwrap();

// Gather the equivalence properties from the new data source. There should
// be no equivalence class for column c2 since it was removed by the projection.
let eq_properties = data_source.eq_properties();
let eq_group = eq_properties.eq_group();

for class in eq_group.iter() {
for expr in class.iter() {
if let Some(col) = expr.as_any().downcast_ref::<Column>() {
assert_ne!(
col.name(),
"c2",
"c2 should not be present in any equivalence class"
);
}
}
}
}

#[test]
fn test_file_scan_config_builder_defaults() {
let file_schema = aggr_test_schema();
12 changes: 12 additions & 0 deletions datafusion/datasource/src/test_util.rs
@@ -34,6 +34,14 @@ pub(crate) struct MockSource {
metrics: ExecutionPlanMetricsSet,
projected_statistics: Option<Statistics>,
schema_adapter_factory: Option<Arc<dyn SchemaAdapterFactory>>,
filter: Option<Arc<dyn PhysicalExpr>>,
}

impl MockSource {
pub fn with_filter(mut self, filter: Arc<dyn PhysicalExpr>) -> Self {
self.filter = Some(filter);
self
}
}

impl FileSource for MockSource {
@@ -50,6 +58,10 @@ impl FileSource for MockSource {
self
}

fn filter(&self) -> Option<Arc<dyn PhysicalExpr>> {
self.filter.clone()
}

fn with_batch_size(&self, _batch_size: usize) -> Arc<dyn FileSource> {
Arc::new(Self { ..self.clone() })
}
8 changes: 8 additions & 0 deletions datafusion/execution/src/task.rs
@@ -201,6 +201,14 @@ impl FunctionRegistry for TaskContext {
fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
vec![]
}

fn udafs(&self) -> HashSet<String> {
self.aggregate_functions.keys().cloned().collect()
}

fn udwfs(&self) -> HashSet<String> {
self.window_functions.keys().cloned().collect()
}
}

#[cfg(test)]