Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public void extractUsNsTest() {
}

@Test
public void extractTimeFilterPushDownTest() {
public void extractFilterPushDownTest() {
String[] expectedHeader = new String[] {"time", "s1"};
String[] retArray =
new String[] {
Expand All @@ -129,23 +129,34 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
// verify the pushdown result is same with non-pushdown
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) > 1",
expectedHeader,
retArray,
DATABASE_NAME);
// verify the pushdown result is same with non-pushdown
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) + 1 > 2",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from time) >= 2",
expectedHeader,
retArray,
DATABASE_NAME);
// verify the pushdown result is same with non-pushdown
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) >= 2",
expectedHeader,
retArray,
DATABASE_NAME);
// verify the pushdown result is same with non-pushdown
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) + 1>= 3",
expectedHeader,
retArray,
DATABASE_NAME);

retArray =
new String[] {
getTimeStrUTC8("2025-07-08T10:18:51") + ",2,",
Expand All @@ -162,6 +173,12 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) + 1 > 10",
"+08:00",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from time) >= 10",
"+08:00",
Expand All @@ -174,6 +191,12 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts)+1>= 11",
"+08:00",
expectedHeader,
retArray,
DATABASE_NAME);

expectedHeader = new String[] {"time", "s1"};
retArray =
Expand All @@ -190,6 +213,11 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) + 1< 2",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from time) <= 0",
expectedHeader,
Expand All @@ -200,6 +228,11 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) +1<= 1",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from time) = 0",
expectedHeader,
Expand All @@ -210,6 +243,11 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) +1= 1",
expectedHeader,
retArray,
DATABASE_NAME);
retArray =
new String[] {
getTimeStrUTC8("2025-07-09T08:17:50") + ",3,",
Expand All @@ -226,6 +264,12 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) + 1 < 10",
"+08:00",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from time) <= 8",
"+08:00",
Expand All @@ -238,6 +282,12 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) +1<= 9",
"+08:00",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from time) = 8",
"+08:00",
Expand All @@ -250,6 +300,12 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) +1= 9",
"+08:00",
expectedHeader,
retArray,
DATABASE_NAME);

retArray =
new String[] {
Expand All @@ -266,6 +322,27 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) +1!= 1",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from time) between 1 and 2",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) between 1 and 2",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) +1 between 2 and 3",
expectedHeader,
retArray,
DATABASE_NAME);

retArray =
new String[] {
getTimeStrUTC8("2025-07-08T09:18:51") + ",1,",
Expand All @@ -283,6 +360,30 @@ public void extractTimeFilterPushDownTest() {
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) +1 != 9",
"+08:00",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from time) between 9 and 10",
"+08:00",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) between 9 and 10",
"+08:00",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
"SELECT time, s1 FROM table1 where extract(hour from ts) +1 between 10 and 11",
"+08:00",
expectedHeader,
retArray,
DATABASE_NAME);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DoubleLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Extract;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.GenericLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.IfExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.InListExpression;
Expand All @@ -47,13 +48,15 @@
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
import org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager;

import com.google.common.collect.ImmutableList;
import org.apache.tsfile.common.conf.TSFileConfig;
import org.apache.tsfile.common.regexp.LikePattern;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.type.Type;
import org.apache.tsfile.read.filter.basic.Filter;
import org.apache.tsfile.read.filter.factory.FilterFactory;
import org.apache.tsfile.read.filter.factory.ValueFilterApi;
import org.apache.tsfile.read.filter.operator.ExtractTimeFilterOperators;
import org.apache.tsfile.utils.Binary;

import javax.annotation.Nullable;
Expand All @@ -72,18 +75,24 @@
import static org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate.ConvertPredicateToTimeFilterVisitor.getLongValue;
import static org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate.PredicatePushIntoScanChecker.isLiteral;
import static org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate.PredicatePushIntoScanChecker.isSymbolReference;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.GlobalTimePredicateExtractVisitor.isExtractTimeColumn;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.GlobalTimePredicateExtractVisitor.isTimeColumn;
import static org.apache.tsfile.read.common.type.TimestampType.TIMESTAMP;

public class ConvertPredicateToFilterVisitor
extends PredicateVisitor<Filter, ConvertPredicateToFilterVisitor.Context> {

@Nullable private final String timeColumnName;
private final ConvertPredicateToTimeFilterVisitor timeFilterVisitor;
private final ZoneId zoneId;
private final TimeUnit currPrecision;

public ConvertPredicateToFilterVisitor(
@Nullable String timeColumnName, ZoneId zoneId, TimeUnit currPrecision) {
this.timeColumnName = timeColumnName;
this.timeFilterVisitor = new ConvertPredicateToTimeFilterVisitor(zoneId, currPrecision);
this.zoneId = zoneId;
this.currPrecision = currPrecision;
}

@Override
Expand Down Expand Up @@ -168,6 +177,48 @@ public static <T extends Comparable<T>> Filter constructCompareFilter(
}
}

private Filter constructExtractCompareFilter(
ComparisonExpression.Operator operator,
SymbolReference symbolReference,
Extract.Field field,
Literal literal,
Context context) {

if (!context.isMeasurementColumn(symbolReference)) {
throw new IllegalStateException(
String.format("Only support measurement column in filter: %s", symbolReference));
}

int measurementIndex = context.getMeasurementIndex(symbolReference.getName());
long value = getValue(literal, TIMESTAMP);
ExtractTimeFilterOperators.Field field1 =
ExtractTimeFilterOperators.Field.values()[field.ordinal()];

switch (operator) {
case EQUAL:
return ValueFilterApi.extractValueEq(
measurementIndex, value, field1, zoneId, currPrecision);
case NOT_EQUAL:
return ValueFilterApi.extractValueNotEq(
measurementIndex, value, field1, zoneId, currPrecision);
case GREATER_THAN:
return ValueFilterApi.extractValueGt(
measurementIndex, value, field1, zoneId, currPrecision);
case GREATER_THAN_OR_EQUAL:
return ValueFilterApi.extractValueGtEq(
measurementIndex, value, field1, zoneId, currPrecision);
case LESS_THAN:
return ValueFilterApi.extractValueLt(
measurementIndex, value, field1, zoneId, currPrecision);
case LESS_THAN_OR_EQUAL:
return ValueFilterApi.extractValueLtEq(
measurementIndex, value, field1, zoneId, currPrecision);
default:
throw new IllegalArgumentException(
String.format("Unsupported extract comparison operator %s", operator));
}
}

@SuppressWarnings("unchecked")
public static <T extends Comparable<T>> T getValue(Literal value, Type dataType) {
try {
Expand Down Expand Up @@ -273,6 +324,22 @@ && isSymbolReference(right)
&& context.isMeasurementColumn((SymbolReference) right)) {
return constructCompareFilter(
node.getOperator().flip(), (SymbolReference) right, (Literal) left, context);
} else if (context.isExtractMeasurementColumn(left) && isLiteral(right)) {
Extract extract = (Extract) left;
return constructExtractCompareFilter(
node.getOperator(),
(SymbolReference) extract.getExpression(),
extract.getField(),
(Literal) right,
context);
} else if (isLiteral(left) && context.isExtractMeasurementColumn(right)) {
Extract extract = (Extract) right;
return constructExtractCompareFilter(
node.getOperator().flip(),
(SymbolReference) extract.getExpression(),
extract.getField(),
(Literal) left,
context);
} else {
throw new IllegalStateException(
String.format("%s is not supported in value push down", node));
Expand Down Expand Up @@ -307,7 +374,10 @@ protected Filter visitBetweenPredicate(BetweenPredicate node, Context context) {

if (isTimeColumn(firstExpression, timeColumnName)
|| isTimeColumn(secondExpression, timeColumnName)
|| isTimeColumn(thirdExpression, timeColumnName)) {
|| isTimeColumn(thirdExpression, timeColumnName)
|| isExtractTimeColumn(firstExpression, timeColumnName)
|| isExtractTimeColumn(secondExpression, timeColumnName)
|| isExtractTimeColumn(thirdExpression, timeColumnName)) {
return timeFilterVisitor.process(node, null);
}

Expand All @@ -331,6 +401,33 @@ protected Filter visitBetweenPredicate(BetweenPredicate node, Context context) {
(SymbolReference) thirdExpression,
(Literal) firstExpression,
context);
} else if (context.isExtractMeasurementColumn(firstExpression)) {
checkArgument(isLiteral(secondExpression));
checkArgument(isLiteral(thirdExpression));
long minValue = getLongValue(secondExpression);
long maxValue = getLongValue(thirdExpression);
Extract extract = (Extract) firstExpression;
int measurementIndex =
context.getMeasurementIndex(((SymbolReference) extract.getExpression()).getName());
ExtractTimeFilterOperators.Field field =
ExtractTimeFilterOperators.Field.values()[extract.getField().ordinal()];

if (minValue == maxValue) {
return ValueFilterApi.extractValueEq(
measurementIndex, minValue, field, zoneId, currPrecision);
}
return FilterFactory.and(
ImmutableList.of(
ValueFilterApi.extractValueGtEq(
measurementIndex, minValue, field, zoneId, currPrecision),
ValueFilterApi.extractValueLtEq(
measurementIndex, maxValue, field, zoneId, currPrecision)));
} else if (context.isExtractMeasurementColumn(secondExpression)) {
throw new IllegalStateException(
"Should not reach here before PredicateCombineIntoTableScanChecker support Extract push-down in third child");
} else if (context.isExtractMeasurementColumn(thirdExpression)) {
throw new IllegalStateException(
"Should not reach here before PredicateCombineIntoTableScanChecker support Extract push-down in third child");
} else {
throw new IllegalStateException(
String.format("%s is not supported in value push down", node));
Expand Down Expand Up @@ -429,5 +526,14 @@ public boolean isMeasurementColumn(SymbolReference symbolReference) {
ColumnSchema schema = schemaMap.get(Symbol.from(symbolReference));
return schema != null && schema.getColumnCategory() == TsTableColumnCategory.FIELD;
}

public boolean isExtractMeasurementColumn(Expression expression) {
if (expression instanceof Extract
&& ((Extract) expression).getExpression() instanceof SymbolReference) {
ColumnSchema schema = schemaMap.get(Symbol.from(((Extract) expression).getExpression()));
return schema != null && schema.getColumnCategory() == TsTableColumnCategory.FIELD;
}
return false;
}
}
}
Loading
Loading