Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,61 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::atomic::Ordering;

use databend_common_exception::Result;
use databend_common_expression::BlockEntry;
use databend_common_expression::Column;
use databend_common_expression::DataBlock;
use databend_common_expression::RepeatIndex;
use databend_common_expression::ScalarRef;
use databend_common_expression::SortColumnDescription;
use databend_common_expression::Value;
use databend_common_expression::types::AccessType;
use databend_common_expression::types::NumberColumn;
use databend_common_expression::types::NumberScalar;
use databend_common_expression::types::UInt64Type;

use crate::pipelines::processors::transforms::range_join::RangeJoinState;
use crate::pipelines::processors::transforms::range_join::filter_block;

impl RangeJoinState {
pub fn range_join(&self, task_id: usize) -> Result<Vec<DataBlock>> {
// Merge range join originally only served Inner/Cross joins, so it could return
// matched pairs directly without tracking unmatched rows. ASOF LEFT/RIGHT joins now
// reuse this path after the nullable interval-end rewrite, which means we must record
// rows that survive `other_conditions` filtering and then run the existing outer-fill
// tasks for the remaining probe/build rows.
let partition_count = self.partition_count.load(Ordering::SeqCst) as usize;
if task_id >= partition_count {
if !self.left_match.read().is_empty() {
return Ok(vec![self.fill_outer(task_id, true)?]);
} else if !self.right_match.read().is_empty() {
return Ok(vec![self.fill_outer(task_id, false)?]);
}
return Ok(vec![DataBlock::empty()]);
}
let result = self.range_join_partition(task_id);
self.completed_pair.fetch_add(1, Ordering::SeqCst);
result
}

// Used by range join
fn sort_descriptions(&self, _: bool) -> Vec<SortColumnDescription> {
let op = &self.conditions[0].operator;
let asc = match op.as_str() {
"gt" | "gte" => false,
"lt" | "lte" => true,
_ => unreachable!(),
};
vec![SortColumnDescription {
offset: 0,
asc,
nulls_first: true,
}]
}

fn range_join_partition(&self, task_id: usize) -> Result<Vec<DataBlock>> {
let tasks = self.tasks.read();
let (left_idx, right_idx) = tasks[task_id];
let left_sorted_blocks = self.left_sorted_blocks.read();
Expand Down Expand Up @@ -59,6 +102,10 @@ impl RangeJoinState {
let mut result_blocks = Vec::with_capacity(left_len);
let left_table = self.left_table.read();
let right_table = self.right_table.read();
let track_left_outer = !self.left_match.read().is_empty();
let track_right_outer = !self.right_match.read().is_empty();
let mut matched_left = Vec::with_capacity(left_len);
let mut matched_right = Vec::with_capacity(right_len);

while i < left_len {
if j == right_len {
Expand All @@ -79,9 +126,12 @@ impl RangeJoinState {
) {
let mut left_result_block = DataBlock::empty();
let mut right_buffer = Vec::with_capacity(right_len - j);
let mut right_match_buffer = Vec::with_capacity(right_len - j);
let mut left_match_index = None;
if let ScalarRef::Number(NumberScalar::Int64(left)) =
unsafe { left_idx_col.index_unchecked(i) }
{
left_match_index = Some((left - 1) as usize);
left_result_block = left_table[left_idx].take_compacted_indices(
&[RepeatIndex {
row: ((left - 1) as usize - left_offset) as u32,
Expand All @@ -95,39 +145,76 @@ impl RangeJoinState {
unsafe { right_idx_col.index_unchecked(k) }
{
right_buffer.push(((-right - 1) as usize - right_offset) as u32);
if track_right_outer {
right_match_buffer.push(((-right - 1) as usize) as u64);
}
}
}
if !left_result_block.is_empty() {
let right_result_block =
right_table[right_idx].take(right_buffer.as_slice())?;
// Merge left_result_block and right_result_block
left_result_block.merge_block(right_result_block);
if track_right_outer {
left_result_block.add_entry(BlockEntry::new(
Value::Column(Column::Number(NumberColumn::UInt64(
right_match_buffer.into(),
))),
|| {
(
databend_common_expression::types::DataType::Number(
databend_common_expression::types::NumberDataType::UInt64,
),
left_result_block.num_rows(),
)
},
));
}
for filter in self.other_conditions.iter() {
left_result_block = filter_block(left_result_block, filter)?;
}
if track_left_outer && !left_result_block.is_empty() {
if let Some(left_match_index) = left_match_index {
matched_left.push(left_match_index);
}
}
if track_right_outer && !left_result_block.is_empty() {
let column = &left_result_block
.columns()
.last()
.unwrap()
.value()
.try_downcast::<UInt64Type>()
.unwrap();
if let Value::Column(col) = column {
matched_right
.extend(UInt64Type::iter_column(col).map(|idx| idx as usize));
}
left_result_block.pop_columns(1);
}
result_blocks.push(left_result_block);
}
i += 1;
} else {
j += 1;
}
}
Ok(result_blocks)
}

// Used by range join
fn sort_descriptions(&self, _: bool) -> Vec<SortColumnDescription> {
let op = &self.conditions[0].operator;
let asc = match op.as_str() {
"gt" | "gte" => false,
"lt" | "lte" => true,
_ => unreachable!(),
};
vec![SortColumnDescription {
offset: 0,
asc,
nulls_first: true,
}]
if track_left_outer && !matched_left.is_empty() {
let mut left_match = self.left_match.write();
for idx in matched_left {
left_match.set(idx, true);
}
}

if track_right_outer && !matched_right.is_empty() {
let mut right_match = self.right_match.write();
for idx in matched_right {
right_match.set(idx, true);
}
}

Ok(result_blocks)
}
}

Expand Down
175 changes: 127 additions & 48 deletions src/query/sql/src/planner/binder/bind_table_reference/bind_asof_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,40 +84,32 @@ impl Binder {
std::mem::swap(&mut condition.left, &mut condition.right)
}

let span = right_column.span();
let arguments = [
right_column,
BoundColumnRef {
span,
column: ColumnBindingBuilder::new(
window_func.display_name.clone(),
window_info.index,
Box::new(window_func.func.return_type()),
Visibility::Visible,
)
.build(),
}
.into(),
]
.to_vec();

let func_name = match range_func.func_name.as_str() {
GTE => LT,
GT => LTE,
LT => GTE,
LTE => GT,
_ => unreachable!(),
};
let span = right_column.span();
let lead_column = BoundColumnRef {
span,
column: ColumnBindingBuilder::new(
window_func.display_name.clone(),
window_info.index,
Box::new(window_func.func.return_type()),
Visibility::Visible,
)
.build(),
}
.to_string();
join.non_equi_conditions.push(
FunctionCall {
span: range_func.span,
params: vec![],
arguments,
.into();
join.non_equi_conditions
.push(make_asof_interval_end_condition(
range_func.span,
right_column,
lead_column,
func_name,
}
.into(),
);
));

let window_plan = bind_window_function_info(&self.ctx, window_info, right)?;
Ok(SExpr::create_binary(
Expand Down Expand Up @@ -158,26 +150,6 @@ impl Binder {
_ => unreachable!(),
};

let constant_default = {
let value = if asc {
left_column
.data_type()?
.remove_nullable()
.infinity()
.unwrap()
} else {
left_column
.data_type()?
.remove_nullable()
.ninfinity()
.unwrap()
};
ConstantExpr {
span: left_column.span(),
value,
}
};

let order_items = vec![WindowOrderBy {
expr: left_column.clone(),
asc: Some(asc),
Expand All @@ -189,12 +161,13 @@ impl Binder {
partition_items.push(condition.right.clone());
}

let return_type = asof_window_result_type(&left_column.data_type()?);
let func_type = WindowFuncType::LagLead(LagLeadFunction {
is_lag: false,
return_type: Box::new(left_column.data_type()?.clone()),
return_type: Box::new(return_type),
arg: Box::new(left_column),
offset: 1,
default: Some(Box::new(constant_default.into())),
default: None,
});

let window_func = WindowFunc {
Expand All @@ -217,6 +190,45 @@ impl Binder {
}
}

fn asof_window_result_type(
data_type: &databend_common_expression::types::DataType,
) -> databend_common_expression::types::DataType {
data_type.wrap_nullable()
}

fn make_asof_interval_end_condition(
span: databend_common_ast::Span,
probe_key: ScalarExpr,
lead_key: ScalarExpr,
func_name: &str,
) -> ScalarExpr {
let compare = ScalarExpr::FunctionCall(FunctionCall {
span,
func_name: func_name.to_string(),
params: vec![],
arguments: vec![probe_key, lead_key.clone()],
});

ScalarExpr::FunctionCall(FunctionCall {
span,
func_name: "if".to_string(),
params: vec![],
arguments: vec![
ScalarExpr::FunctionCall(FunctionCall {
span,
func_name: "is_not_null".to_string(),
params: vec![],
arguments: vec![lead_key],
}),
compare,
ScalarExpr::ConstantExpr(ConstantExpr {
span,
value: Scalar::Boolean(true),
}),
],
})
}

pub fn is_range_join_condition<'a>(
expr: &'a ScalarExpr,
left_prop: &RelationalProperty,
Expand Down Expand Up @@ -252,3 +264,70 @@ pub fn is_range_join_condition<'a>(
_ => None,
}
}

#[cfg(test)]
mod tests {
use databend_common_expression::types::DataType;
use databend_common_expression::types::NumberDataType;

use super::*;
use crate::Symbol;

fn test_column(name: &str, index: usize, data_type: DataType) -> ScalarExpr {
BoundColumnRef {
span: None,
column: ColumnBindingBuilder::new(
name.to_string(),
Symbol::from_field_index(index),
Box::new(data_type),
Visibility::Visible,
)
.build(),
}
.into()
}

#[test]
fn test_asof_interval_end_condition_guards_open_tail_with_null_lead() {
let probe = test_column("probe", 0, DataType::Number(NumberDataType::UInt8));
let lead = test_column(
"lead",
1,
DataType::Number(NumberDataType::UInt8).wrap_nullable(),
);

let expr = make_asof_interval_end_condition(None, probe.clone(), lead.clone(), LT);
let ScalarExpr::FunctionCall(func) = expr else {
panic!("expected function call");
};

assert_eq!(func.func_name, "if");
assert_eq!(func.arguments.len(), 3);

let ScalarExpr::FunctionCall(not_null) = &func.arguments[0] else {
panic!("expected is_not_null guard");
};
assert_eq!(not_null.func_name, "is_not_null");
assert_eq!(not_null.arguments, vec![lead.clone()]);

let ScalarExpr::FunctionCall(compare) = &func.arguments[1] else {
panic!("expected comparison branch");
};
assert_eq!(compare.func_name, LT);
assert_eq!(compare.arguments, vec![probe, lead]);

let ScalarExpr::ConstantExpr(constant) = &func.arguments[2] else {
panic!("expected constant true branch");
};
assert_eq!(constant.value, Scalar::Boolean(true));
}

#[test]
fn test_asof_window_result_type_is_nullable() {
let data_type = DataType::Number(NumberDataType::UInt8);
assert_eq!(
asof_window_result_type(&data_type),
data_type.wrap_nullable()
);
}
}
Loading
Loading