Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Transpose;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Union;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.ast.tree.Window;
Expand Down Expand Up @@ -897,6 +898,11 @@ public LogicalPlan visitMultisearch(Multisearch node, AnalysisContext context) {
throw getOnlyForCalciteException("Multisearch");
}

@Override
public LogicalPlan visitUnion(Union node, AnalysisContext context) {
throw getOnlyForCalciteException("Union");
}

private LogicalSort buildSort(
LogicalPlan child, AnalysisContext context, Integer count, List<Field> sortFields) {
ExpressionReferenceOptimizer optimizer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Transpose;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Union;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.ast.tree.Window;

Expand Down Expand Up @@ -472,6 +473,10 @@ public T visitMultisearch(Multisearch node, C context) {
return visitChildren(node, context);
}

public T visitUnion(Union node, C context) {
return visitChildren(node, context);
}

public T visitAddTotals(AddTotals node, C context) {
return visitChildren(node, context);
}
Expand Down
44 changes: 44 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Union.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/** Logical plan node for Union operation. Combines results from multiple datasets (UNION ALL). */
@Getter
@ToString
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
@AllArgsConstructor
public class Union extends UnresolvedPlan {
private final List<UnresolvedPlan> datasets;

private Integer maxout;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
List<UnresolvedPlan> newDatasets =
ImmutableList.<UnresolvedPlan>builder().add(child).addAll(datasets).build();
return new Union(newDatasets, maxout);
}

@Override
public List<? extends UnresolvedPlan> getChild() {
return datasets;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitUnion(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Trendline.TrendlineType;
import org.opensearch.sql.ast.tree.Union;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.ast.tree.Window;
Expand Down Expand Up @@ -2627,6 +2628,40 @@ private String findTimestampField(RelDataType rowType) {
return null;
}

@Override
public RelNode visitUnion(Union node, CalcitePlanContext context) {
List<RelNode> inputNodes = new ArrayList<>();

for (UnresolvedPlan dataset : node.getDatasets()) {
UnresolvedPlan prunedDataset = dataset.accept(new EmptySourcePropagateVisitor(), null);
prunedDataset.accept(this, context);
inputNodes.add(context.relBuilder.build());
}

if (inputNodes.size() < 2) {
throw new IllegalArgumentException(
"Union command requires at least two datasets. Provided: " + inputNodes.size());
}

List<RelNode> unifiedInputs =
SchemaUnifier.buildUnifiedSchemaWithTypeCoercion(inputNodes, context);

for (RelNode input : unifiedInputs) {
context.relBuilder.push(input);
}
context.relBuilder.union(true, unifiedInputs.size()); // true = UNION ALL

if (node.getMaxout() != null) {
Comment thread
srikanthpadakanti marked this conversation as resolved.
context.relBuilder.push(
LogicalSystemLimit.create(
LogicalSystemLimit.SystemLimitType.SUBSEARCH_MAXOUT,
context.relBuilder.build(),
context.relBuilder.literal(node.getMaxout())));
}

return context.relBuilder.peek();
}

/*
* Unsupported Commands of PPL with Calcite for OpenSearch 3.0.0-beta
*/
Expand Down
242 changes: 240 additions & 2 deletions core/src/main/java/org/opensearch/sql/calcite/SchemaUnifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.type.SqlTypeName;

/**
* Utility class for unifying schemas across multiple RelNodes. Throws an exception when type
* conflicts are detected.
* Utility class for unifying schemas across multiple RelNodes. Supports two strategies:
*
* <ul>
* <li>Conflict resolution (multisearch): throws on type mismatch, fills missing fields with NULL
* <li>Type coercion (union): widens compatible types (e.g. INTEGER→BIGINT), falls back to VARCHAR
* for incompatible types, fills missing fields with NULL
* </ul>
*/
public class SchemaUnifier {

Expand Down Expand Up @@ -147,4 +153,236 @@ RelDataType getType() {
return type;
}
}

/**
* Builds unified schema with type coercion for UNION command. Coerces compatible types to a
* common supertype (e.g. int+float→float), falls back to VARCHAR for incompatible types, and
* fills missing fields with NULL.
*/
public static List<RelNode> buildUnifiedSchemaWithTypeCoercion(
List<RelNode> inputs, CalcitePlanContext context) {
if (inputs.isEmpty() || inputs.size() == 1) {
return inputs;
}

List<RelNode> coercedInputs = coerceUnionTypes(inputs, context);
return unifySchemasForUnion(coercedInputs, context);
}

/**
* Aligns schemas by projecting NULL for missing fields and CAST for type mismatches. Uses
* force=true to clear collation traits and prevent EnumerableMergeUnion cast exception.
*/
private static List<RelNode> unifySchemasForUnion(
List<RelNode> inputs, CalcitePlanContext context) {
List<SchemaField> unifiedSchema = buildUnifiedSchemaForUnion(inputs);
List<String> fieldNames =
unifiedSchema.stream().map(SchemaField::getName).collect(Collectors.toList());

List<RelNode> projectedNodes = new ArrayList<>();
for (RelNode node : inputs) {
List<RexNode> projection = buildProjectionForUnion(node, unifiedSchema, context);
RelNode projectedNode =
context.relBuilder.push(node).project(projection, fieldNames, true).build();
projectedNodes.add(projectedNode);
}
return projectedNodes;
}

private static List<SchemaField> buildUnifiedSchemaForUnion(List<RelNode> nodes) {
List<SchemaField> schema = new ArrayList<>();
Map<String, RelDataType> seenFields = new HashMap<>();

for (RelNode node : nodes) {
for (RelDataTypeField field : node.getRowType().getFieldList()) {
if (!seenFields.containsKey(field.getName())) {
schema.add(new SchemaField(field.getName(), field.getType()));
seenFields.put(field.getName(), field.getType());
Comment thread
srikanthpadakanti marked this conversation as resolved.
}
}
}
return schema;
}

private static List<RexNode> buildProjectionForUnion(
RelNode node, List<SchemaField> unifiedSchema, CalcitePlanContext context) {
Map<String, RelDataTypeField> nodeFieldMap =
node.getRowType().getFieldList().stream()
.collect(Collectors.toMap(RelDataTypeField::getName, field -> field));

List<RexNode> projection = new ArrayList<>();
for (SchemaField schemaField : unifiedSchema) {
RelDataTypeField nodeField = nodeFieldMap.get(schemaField.getName());

if (nodeField != null) {
RexNode fieldRef = context.rexBuilder.makeInputRef(node, nodeField.getIndex());
if (!nodeField.getType().equals(schemaField.getType())) {
projection.add(context.rexBuilder.makeCast(schemaField.getType(), fieldRef));
} else {
projection.add(fieldRef);
}
} else {
projection.add(context.rexBuilder.makeNullLiteral(schemaField.getType()));
}
}
return projection;
}

/** Casts fields to their common supertypes across all inputs when types differ. */
private static List<RelNode> coerceUnionTypes(List<RelNode> inputs, CalcitePlanContext context) {
Map<String, List<SqlTypeName>> fieldTypeMap = new HashMap<>();
for (RelNode input : inputs) {
for (RelDataTypeField field : input.getRowType().getFieldList()) {
String fieldName = field.getName();
SqlTypeName typeName = field.getType().getSqlTypeName();
if (typeName != null) {
fieldTypeMap.computeIfAbsent(fieldName, k -> new ArrayList<>()).add(typeName);
}
}
}

Map<String, SqlTypeName> targetTypeMap = new HashMap<>();
for (Map.Entry<String, List<SqlTypeName>> entry : fieldTypeMap.entrySet()) {
String fieldName = entry.getKey();
List<SqlTypeName> types = entry.getValue();

SqlTypeName commonType = types.getFirst();
for (int i = 1; i < types.size(); i++) {
commonType = findCommonTypeForUnion(commonType, types.get(i));
}
targetTypeMap.put(fieldName, commonType);
}

boolean needsCoercion = false;
for (RelNode input : inputs) {
for (RelDataTypeField field : input.getRowType().getFieldList()) {
SqlTypeName targetType = targetTypeMap.get(field.getName());
if (targetType != null && field.getType().getSqlTypeName() != targetType) {
needsCoercion = true;
break;
}
}
if (needsCoercion) break;
}

if (!needsCoercion) {
return inputs;
}

List<RelNode> coercedInputs = new ArrayList<>();
for (RelNode input : inputs) {
List<RexNode> projections = new ArrayList<>();
List<String> projectionNames = new ArrayList<>();
boolean needsProjection = false;

for (RelDataTypeField field : input.getRowType().getFieldList()) {
String fieldName = field.getName();
SqlTypeName currentType = field.getType().getSqlTypeName();
SqlTypeName targetType = targetTypeMap.get(fieldName);

RexNode fieldRef = context.rexBuilder.makeInputRef(input, field.getIndex());

if (currentType != targetType && targetType != null) {
projections.add(context.relBuilder.cast(fieldRef, targetType));
needsProjection = true;
} else {
projections.add(fieldRef);
}
projectionNames.add(fieldName);
}

if (needsProjection) {
context.relBuilder.push(input);
context.relBuilder.project(projections, projectionNames, true);
coercedInputs.add(context.relBuilder.build());
} else {
coercedInputs.add(input);
}
}

return coercedInputs;
}

/**
* Returns the wider type for two SqlTypeNames. Within the same family, returns the wider type
* (e.g. INTEGER+BIGINT-->BIGINT). Across families, falls back to VARCHAR.
*/
private static SqlTypeName findCommonTypeForUnion(SqlTypeName type1, SqlTypeName type2) {
if (type1 == type2) {
return type1;
}

if (type1 == SqlTypeName.NULL) {
return type2;
}
if (type2 == SqlTypeName.NULL) {
return type1;
}

if (isNumericTypeForUnion(type1) && isNumericTypeForUnion(type2)) {
return getWiderNumericTypeForUnion(type1, type2);
}

if (isStringTypeForUnion(type1) && isStringTypeForUnion(type2)) {
return SqlTypeName.VARCHAR;
}

if (isTemporalTypeForUnion(type1) && isTemporalTypeForUnion(type2)) {
return getWiderTemporalTypeForUnion(type1, type2);
}

return SqlTypeName.VARCHAR;
}

private static boolean isNumericTypeForUnion(SqlTypeName typeName) {
return typeName == SqlTypeName.TINYINT
|| typeName == SqlTypeName.SMALLINT
|| typeName == SqlTypeName.INTEGER
|| typeName == SqlTypeName.BIGINT
|| typeName == SqlTypeName.FLOAT
|| typeName == SqlTypeName.REAL
|| typeName == SqlTypeName.DOUBLE
|| typeName == SqlTypeName.DECIMAL;
}

private static boolean isStringTypeForUnion(SqlTypeName typeName) {
return typeName == SqlTypeName.CHAR || typeName == SqlTypeName.VARCHAR;
}

private static boolean isTemporalTypeForUnion(SqlTypeName typeName) {
return typeName == SqlTypeName.DATE
|| typeName == SqlTypeName.TIMESTAMP
|| typeName == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
}

private static SqlTypeName getWiderNumericTypeForUnion(SqlTypeName type1, SqlTypeName type2) {
int rank1 = getNumericTypeRankForUnion(type1);
int rank2 = getNumericTypeRankForUnion(type2);
return rank1 >= rank2 ? type1 : type2;
}

private static int getNumericTypeRankForUnion(SqlTypeName typeName) {
return switch (typeName) {
case TINYINT -> 1;
case SMALLINT -> 2;
case INTEGER -> 3;
case BIGINT -> 4;
case DECIMAL -> 5;
case REAL -> 6;
case FLOAT -> 7;
case DOUBLE -> 8;
default -> 0;
};
}

private static SqlTypeName getWiderTemporalTypeForUnion(SqlTypeName type1, SqlTypeName type2) {
if (type1 == SqlTypeName.TIMESTAMP || type2 == SqlTypeName.TIMESTAMP) {
Comment thread
srikanthpadakanti marked this conversation as resolved.
return SqlTypeName.TIMESTAMP;
}
if (type1 == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE
|| type2 == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE) {
return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
}
return SqlTypeName.DATE;
}
}
1 change: 1 addition & 0 deletions docs/category.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"user/ppl/cmd/top.md",
"user/ppl/cmd/trendline.md",
"user/ppl/cmd/transpose.md",
"user/ppl/cmd/union.md",
"user/ppl/cmd/where.md",
"user/ppl/functions/aggregations.md",
"user/ppl/functions/collection.md",
Expand Down
Loading
Loading