|
| 1 | +package org.hypertrace.core.documentstore.mongo.query; |
| 2 | + |
| 3 | +import static java.util.Collections.singleton; |
| 4 | +import static java.util.function.Function.identity; |
| 5 | +import static java.util.function.Predicate.not; |
| 6 | +import static java.util.stream.Collectors.toUnmodifiableList; |
| 7 | +import static org.hypertrace.core.documentstore.mongo.query.MongoPaginationHelper.getLimitClause; |
| 8 | +import static org.hypertrace.core.documentstore.mongo.query.MongoPaginationHelper.getSkipClause; |
| 9 | +import static org.hypertrace.core.documentstore.mongo.query.parser.MongoFilterTypeExpressionParser.getFilterClause; |
| 10 | +import static org.hypertrace.core.documentstore.mongo.query.parser.MongoNonProjectedSortTypeExpressionParser.getNonProjectedSortClause; |
| 11 | +import static org.hypertrace.core.documentstore.mongo.query.parser.MongoSelectTypeExpressionParser.getProjectClause; |
| 12 | +import static org.hypertrace.core.documentstore.mongo.query.parser.MongoSortTypeExpressionParser.getSortClause; |
| 13 | + |
| 14 | +import com.mongodb.BasicDBObject; |
| 15 | +import com.mongodb.client.MongoCollection; |
| 16 | +import java.util.Collection; |
| 17 | +import java.util.List; |
| 18 | +import java.util.Map; |
| 19 | +import java.util.Optional; |
| 20 | +import java.util.function.Function; |
| 21 | +import java.util.stream.Collectors; |
| 22 | +import lombok.AllArgsConstructor; |
| 23 | +import lombok.extern.slf4j.Slf4j; |
| 24 | +import org.hypertrace.core.documentstore.model.config.AggregatePipelineMode; |
| 25 | +import org.hypertrace.core.documentstore.mongo.query.parser.AliasParser; |
| 26 | +import org.hypertrace.core.documentstore.mongo.query.parser.MongoFromTypeExpressionParser; |
| 27 | +import org.hypertrace.core.documentstore.mongo.query.parser.MongoGroupTypeExpressionParser; |
| 28 | +import org.hypertrace.core.documentstore.parser.AggregateExpressionChecker; |
| 29 | +import org.hypertrace.core.documentstore.parser.FunctionExpressionChecker; |
| 30 | +import org.hypertrace.core.documentstore.query.Query; |
| 31 | +import org.hypertrace.core.documentstore.query.SelectionSpec; |
| 32 | +import org.hypertrace.core.documentstore.query.SortingSpec; |
| 33 | + |
| 34 | +@Slf4j |
| 35 | +@AllArgsConstructor |
| 36 | +public class MongoAggregationPipelineConverter { |
| 37 | + private final AggregatePipelineMode aggregationPipelineMode; |
| 38 | + private final MongoCollection<BasicDBObject> collection; |
| 39 | + |
| 40 | + private final List<Function<Query, Collection<BasicDBObject>>> |
| 41 | + DEFAULT_AGGREGATE_PIPELINE_FUNCTIONS = |
| 42 | + List.of( |
| 43 | + query -> singleton(getFilterClause(query, Query::getFilter)), |
| 44 | + query -> new MongoFromTypeExpressionParser(this).getFromClauses(query), |
| 45 | + MongoGroupTypeExpressionParser::getGroupClauses, |
| 46 | + query -> singleton(getProjectClause(query)), |
| 47 | + query -> singleton(getFilterClause(query, Query::getAggregationFilter)), |
| 48 | + query -> singleton(getSortClause(query)), |
| 49 | + query -> singleton(getSkipClause(query)), |
| 50 | + query -> singleton(getLimitClause(query))); |
| 51 | + |
| 52 | + private final List<Function<Query, Collection<BasicDBObject>>> |
| 53 | + SORT_OPTIMISED_AGGREGATE_PIPELINE_FUNCTIONS = |
| 54 | + List.of( |
| 55 | + query -> singleton(getFilterClause(query, Query::getFilter)), |
| 56 | + query -> new MongoFromTypeExpressionParser(this).getFromClauses(query), |
| 57 | + query -> singleton(getNonProjectedSortClause(query)), |
| 58 | + query -> singleton(getSkipClause(query)), |
| 59 | + query -> singleton(getLimitClause(query)), |
| 60 | + query -> singleton(getProjectClause(query))); |
| 61 | + |
| 62 | + public String getCollectionName() { |
| 63 | + return collection.getNamespace().getCollectionName(); |
| 64 | + } |
| 65 | + |
| 66 | + public List<BasicDBObject> convertToAggregatePipeline(Query query) { |
| 67 | + List<Function<Query, Collection<BasicDBObject>>> aggregatePipeline = |
| 68 | + getAggregationPipeline(query); |
| 69 | + |
| 70 | + List<BasicDBObject> pipeline = |
| 71 | + aggregatePipeline.stream() |
| 72 | + .flatMap(function -> function.apply(query).stream()) |
| 73 | + .filter(not(BasicDBObject::isEmpty)) |
| 74 | + .collect(toUnmodifiableList()); |
| 75 | + return pipeline; |
| 76 | + } |
| 77 | + |
| 78 | + private List<Function<Query, Collection<BasicDBObject>>> getAggregationPipeline(Query query) { |
| 79 | + List<Function<Query, Collection<BasicDBObject>>> aggregatePipeline = |
| 80 | + DEFAULT_AGGREGATE_PIPELINE_FUNCTIONS; |
| 81 | + if (aggregationPipelineMode.equals(AggregatePipelineMode.SORT_OPTIMIZED_IF_POSSIBLE) |
| 82 | + && query.getAggregations().isEmpty() |
| 83 | + && query.getAggregationFilter().isEmpty() |
| 84 | + && !isProjectionContainsAggregation(query) |
| 85 | + && !isSortContainsAggregation(query)) { |
| 86 | + log.debug("Using sort optimized aggregate pipeline functions for query: {}", query); |
| 87 | + aggregatePipeline = SORT_OPTIMISED_AGGREGATE_PIPELINE_FUNCTIONS; |
| 88 | + } |
| 89 | + return aggregatePipeline; |
| 90 | + } |
| 91 | + |
| 92 | + private boolean isProjectionContainsAggregation(Query query) { |
| 93 | + return query.getSelections().stream() |
| 94 | + .map(SelectionSpec::getExpression) |
| 95 | + .anyMatch(spec -> spec.accept(new AggregateExpressionChecker())); |
| 96 | + } |
| 97 | + |
| 98 | + private boolean isSortContainsAggregation(Query query) { |
| 99 | + // ideally there should be only one alias per selection, |
| 100 | + // in case of duplicates, we will accept the latest one |
| 101 | + Map<String, SelectionSpec> aliasToSelectionMap = |
| 102 | + query.getSelections().stream() |
| 103 | + .filter(spec -> this.getAlias(spec).isPresent()) |
| 104 | + .collect( |
| 105 | + Collectors.toMap( |
| 106 | + entry -> this.getAlias(entry).orElseThrow(), identity(), (v1, v2) -> v2)); |
| 107 | + return query.getSorts().stream() |
| 108 | + .anyMatch(sort -> isSortOnAggregatedField(aliasToSelectionMap, sort)); |
| 109 | + } |
| 110 | + |
| 111 | + private boolean isSortOnAggregatedField( |
| 112 | + Map<String, SelectionSpec> aliasToSelectionMap, SortingSpec sort) { |
| 113 | + boolean isFunctionExpression = sort.getExpression().accept(new FunctionExpressionChecker()); |
| 114 | + boolean isAggregateExpression = sort.getExpression().accept(new AggregateExpressionChecker()); |
| 115 | + return isFunctionExpression |
| 116 | + || isAggregateExpression |
| 117 | + || isSortOnAggregatedProjection(aliasToSelectionMap, sort); |
| 118 | + } |
| 119 | + |
| 120 | + private Optional<String> getAlias(SelectionSpec selectionSpec) { |
| 121 | + if (selectionSpec.getAlias() != null) { |
| 122 | + return Optional.of(selectionSpec.getAlias()); |
| 123 | + } |
| 124 | + |
| 125 | + return selectionSpec.getExpression().accept(new AliasParser()); |
| 126 | + } |
| 127 | + |
| 128 | + private boolean isSortOnAggregatedProjection( |
| 129 | + Map<String, SelectionSpec> aliasToSelectionMap, SortingSpec sort) { |
| 130 | + Optional<String> alias = sort.getExpression().accept(new AliasParser()); |
| 131 | + if (alias.isEmpty()) { |
| 132 | + throw new UnsupportedOperationException( |
| 133 | + "Cannot sort by an expression that does not have an alias in selection"); |
| 134 | + } |
| 135 | + |
| 136 | + SelectionSpec selectionSpec = aliasToSelectionMap.get(alias.get()); |
| 137 | + if (selectionSpec == null) { |
| 138 | + return false; |
| 139 | + } |
| 140 | + |
| 141 | + Boolean isFunctionExpression = |
| 142 | + selectionSpec.getExpression().accept(new FunctionExpressionChecker()); |
| 143 | + Boolean isAggregationExpression = |
| 144 | + selectionSpec.getExpression().accept(new AggregateExpressionChecker()); |
| 145 | + return isFunctionExpression || isAggregationExpression; |
| 146 | + } |
| 147 | +} |
0 commit comments