Skip to content

Commit 40688ad

Browse files
authored
Merge branch 'main' into unified-allocator-followup
2 parents 041dc0e + c2d0a6c commit 40688ad

37 files changed

Lines changed: 2223 additions & 616 deletions

sandbox/libs/analytics-framework/src/main/java/org/opensearch/analytics/spi/AggregateFunction.java

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
package org.opensearch.analytics.spi;
1010

1111
import org.apache.arrow.vector.types.pojo.ArrowType;
12+
import org.apache.calcite.rel.type.RelDataType;
13+
import org.apache.calcite.rel.type.RelDataTypeFactory;
1214
import org.apache.calcite.sql.SqlAggFunction;
1315
import org.apache.calcite.sql.SqlKind;
1416
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
17+
import org.apache.calcite.sql.type.SqlTypeName;
1518

1619
import java.util.List;
1720

@@ -50,10 +53,12 @@ public enum AggregateFunction {
5053
COLLECT(Type.STATE_EXPANDING, SqlKind.COLLECT),
5154
LISTAGG(Type.STATE_EXPANDING, SqlKind.LISTAGG),
5255

53-
// Approximate — probabilistic, fixed-size state. Engine-native merge: null reducer
54-
// means the field is reduced by this same function (APPROX_COUNT_DISTINCT merges
55-
// partial HLL sketches into a final sketch).
56-
APPROX_COUNT_DISTINCT(Type.APPROXIMATE, SqlKind.OTHER, fields(IF("sketch", new ArrowType.Binary(), null)));
56+
APPROX_COUNT_DISTINCT(Type.APPROXIMATE, SqlKind.OTHER, fields(IF("sketch", new ArrowType.Binary(), null))),
57+
TAKE(Type.STATE_EXPANDING, SqlKind.OTHER, fields(IF("take_state", IntermediateTypeResolver.passThroughArg0(), null))),
58+
FIRST(Type.STATE_EXPANDING, SqlKind.OTHER, fields(IF("first_state", IntermediateTypeResolver.passThroughArg0(), null))),
59+
LAST(Type.STATE_EXPANDING, SqlKind.OTHER, fields(IF("last_state", IntermediateTypeResolver.passThroughArg0(), null))),
60+
LIST(Type.STATE_EXPANDING, SqlKind.OTHER, fields(IF("list_state", IntermediateTypeResolver.passThroughArg0(), null))),
61+
VALUES(Type.STATE_EXPANDING, SqlKind.OTHER, fields(IF("values_state", IntermediateTypeResolver.passThroughArg0(), null)));
5762

5863
/** Category of aggregate function. Affects execution strategy (shuffle vs map-reduce). */
5964
public enum Type {
@@ -63,8 +68,71 @@ public enum Type {
6368
APPROXIMATE
6469
}
6570

66-
/** Describes one intermediate field emitted by a partial aggregate. A null reducer means "self" (the owning enum constant). */
67-
public record IntermediateField(String name, ArrowType arrowType, AggregateFunction reducer) {
71+
/**
72+
* Describes one intermediate field emitted by a partial aggregate. A null reducer means
73+
* "self" (the owning enum constant).
74+
*
75+
* <p>The {@code typeResolver} produces the field's Calcite type given the FINAL aggregate
76+
* call's input arg types. For fixed-shape states (HLL sketch is always Binary, COUNT
77+
* counter is always Int64) the resolver ignores its input and returns a constant; for
78+
* input-parameterised states (e.g. {@code take(field, N)}'s buffer is {@code list<field>})
79+
* the resolver derives the shape from arg 0. Construct via
80+
* {@link IntermediateTypeResolver#fixed(ArrowType)} and
81+
* {@link IntermediateTypeResolver#passThroughArg0()}.
82+
*/
83+
public record IntermediateField(String name, IntermediateTypeResolver typeResolver, AggregateFunction reducer) {
84+
}
85+
86+
/**
87+
* Computes the intermediate-field type from an aggregate call's arg types. Implementations
88+
* must be pure: same input → same output. Two flavours:
89+
* <ul>
90+
* <li><b>Fixed</b> ({@link #fixed(ArrowType)}) — returns a constant Arrow type mapped
91+
* to the corresponding Calcite type. Used for COUNT (Int64) and APPROX_COUNT_DISTINCT
92+
* (Binary).</li>
93+
* <li><b>Input-parameterised</b> (custom impls like {@link #passThroughArg0()}) — derives
94+
* the type from {@code argTypes}. Used for {@code take}/{@code list}/{@code values}
95+
* whose state shape is {@code list<arg0>}.</li>
96+
* </ul>
97+
*/
98+
@FunctionalInterface
99+
public interface IntermediateTypeResolver {
100+
/** Resolve the intermediate field's Calcite type. */
101+
RelDataType resolve(List<RelDataType> argTypes, RelDataTypeFactory typeFactory);
102+
103+
/** Resolver that always returns the Calcite type mapped from the given Arrow type, irrespective of arg types. */
104+
static IntermediateTypeResolver fixed(ArrowType arrowType) {
105+
return (argTypes, typeFactory) -> ArrowToCalciteTypeMapper.toCalcite(arrowType, typeFactory);
106+
}
107+
108+
/**
109+
* Pass arg 0's type through unchanged. Used by state-expanding aggregates whose
110+
* FINAL re-aggregates over PARTIAL's output column — the column type already
111+
* equals the desired exchange shape.
112+
*/
113+
static IntermediateTypeResolver passThroughArg0() {
114+
return (argTypes, typeFactory) -> {
115+
if (argTypes.isEmpty()) {
116+
throw new IllegalStateException("passThroughArg0 resolver requires at least one arg type");
117+
}
118+
return argTypes.get(0);
119+
};
120+
}
121+
}
122+
123+
/**
124+
* Internal Arrow → Calcite mapper used by {@link IntermediateTypeResolver#fixed}. Lives
125+
* in the SPI module so {@code IntermediateField} stays self-contained — no dependency
126+
* on the planner module. Add cases as new fixed-state shapes appear.
127+
*/
128+
private static final class ArrowToCalciteTypeMapper {
129+
static RelDataType toCalcite(ArrowType t, RelDataTypeFactory f) {
130+
return switch (t) {
131+
case ArrowType.Int i when i.getBitWidth() == 64 -> f.createSqlType(SqlTypeName.BIGINT);
132+
case ArrowType.Binary b -> f.createSqlType(SqlTypeName.VARBINARY, Integer.MAX_VALUE);
133+
default -> throw new IllegalArgumentException("Unsupported fixed Arrow type for IntermediateField: " + t);
134+
};
135+
}
68136
}
69137

70138
private final Type type;
@@ -93,7 +161,7 @@ public SqlKind getSqlKind() {
93161
public List<IntermediateField> intermediateFields() {
94162
if (intermediateFields == null) return null;
95163
return intermediateFields.stream()
96-
.map(f -> f.reducer() == null ? new IntermediateField(f.name(), f.arrowType(), this) : f)
164+
.map(f -> f.reducer() == null ? new IntermediateField(f.name(), f.typeResolver(), this) : f)
97165
.toList();
98166
}
99167

@@ -163,6 +231,10 @@ private static List<IntermediateField> fields(IntermediateField... fs) {
163231
}
164232

165233
private static IntermediateField IF(String name, ArrowType arrowType, AggregateFunction reducer) {
166-
return new IntermediateField(name, arrowType, reducer);
234+
return new IntermediateField(name, IntermediateTypeResolver.fixed(arrowType), reducer);
235+
}
236+
237+
private static IntermediateField IF(String name, IntermediateTypeResolver typeResolver, AggregateFunction reducer) {
238+
return new IntermediateField(name, typeResolver, reducer);
167239
}
168240
}

sandbox/libs/analytics-framework/src/test/java/org/opensearch/analytics/spi/AggregateFunctionTests.java

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,26 @@
88

99
package org.opensearch.analytics.spi;
1010

11-
import org.apache.arrow.vector.types.pojo.ArrowType;
11+
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
12+
import org.apache.calcite.rel.type.RelDataType;
13+
import org.apache.calcite.rel.type.RelDataTypeFactory;
1214
import org.apache.calcite.sql.SqlKind;
15+
import org.apache.calcite.sql.type.SqlTypeName;
1316
import org.opensearch.test.OpenSearchTestCase;
1417

1518
import java.util.List;
1619

1720
import static org.opensearch.analytics.spi.AggregateFunction.APPROX_COUNT_DISTINCT;
1821
import static org.opensearch.analytics.spi.AggregateFunction.AVG;
1922
import static org.opensearch.analytics.spi.AggregateFunction.COUNT;
23+
import static org.opensearch.analytics.spi.AggregateFunction.FIRST;
24+
import static org.opensearch.analytics.spi.AggregateFunction.LAST;
25+
import static org.opensearch.analytics.spi.AggregateFunction.LIST;
2026
import static org.opensearch.analytics.spi.AggregateFunction.MAX;
2127
import static org.opensearch.analytics.spi.AggregateFunction.MIN;
2228
import static org.opensearch.analytics.spi.AggregateFunction.SUM;
29+
import static org.opensearch.analytics.spi.AggregateFunction.TAKE;
30+
import static org.opensearch.analytics.spi.AggregateFunction.VALUES;
2331

2432
/**
2533
* Asserts the enum carries the right shape per function for the resolver's three
@@ -34,6 +42,13 @@
3442
*/
3543
public class AggregateFunctionTests extends OpenSearchTestCase {
3644

45+
private final RelDataTypeFactory typeFactory = new JavaTypeFactoryImpl();
46+
private final RelDataType integer = typeFactory.createSqlType(SqlTypeName.INTEGER);
47+
48+
private RelDataType resolve(AggregateFunction.IntermediateField field, RelDataType arg0) {
49+
return field.typeResolver().resolve(List.of(arg0), typeFactory);
50+
}
51+
3752
// ── Pass-through: SUM / MIN / MAX ──
3853

3954
public void testSumHasNoDecomposition() {
@@ -52,8 +67,7 @@ public void testCountIntermediateFields() {
5267
assertEquals(1, fields.size());
5368
assertEquals("count", fields.get(0).name());
5469
assertSame(SUM, fields.get(0).reducer());
55-
assertTrue(fields.get(0).arrowType() instanceof ArrowType.Int);
56-
assertEquals(64, ((ArrowType.Int) fields.get(0).arrowType()).getBitWidth());
70+
assertEquals(SqlTypeName.BIGINT, resolve(fields.get(0), integer).getSqlTypeName());
5771
}
5872

5973
// ── AVG / STDDEV / VAR: handled by Calcite's reduce rule — no enum metadata ──
@@ -77,7 +91,77 @@ public void testApproxCountDistinctReducerIsSelf() {
7791
assertEquals(1, fields.size());
7892
assertEquals("sketch", fields.get(0).name());
7993
assertSame(APPROX_COUNT_DISTINCT, fields.get(0).reducer());
80-
assertTrue(fields.get(0).arrowType() instanceof ArrowType.Binary);
94+
assertEquals(SqlTypeName.VARBINARY, resolve(fields.get(0), integer).getSqlTypeName());
95+
}
96+
97+
// ── TAKE: engine-native (single field, reducer == self, parameterised resolver) ──
98+
99+
public void testTakeHasDecomposition() {
100+
assertTrue(TAKE.hasDecomposition());
101+
}
102+
103+
public void testTakeReducerIsSelfAndResolverIsParameterised() {
104+
List<AggregateFunction.IntermediateField> fields = TAKE.intermediateFields();
105+
assertEquals(1, fields.size());
106+
assertEquals("take_state", fields.get(0).name());
107+
assertSame(TAKE, fields.get(0).reducer());
108+
assertEquals("passThroughArg0 returns arg0", integer, resolve(fields.get(0), integer));
109+
}
110+
111+
// ── FIRST: engine-native (single field, reducer == self, parameterised resolver) ──
112+
113+
public void testFirstHasDecomposition() {
114+
assertTrue(FIRST.hasDecomposition());
115+
}
116+
117+
public void testFirstReducerIsSelf() {
118+
List<AggregateFunction.IntermediateField> fields = FIRST.intermediateFields();
119+
assertEquals(1, fields.size());
120+
assertEquals("first_state", fields.get(0).name());
121+
assertSame(FIRST, fields.get(0).reducer());
122+
assertEquals(integer, resolve(fields.get(0), integer));
123+
}
124+
125+
// ── LAST: engine-native (single field, reducer == self, parameterised resolver) ──
126+
127+
public void testLastHasDecomposition() {
128+
assertTrue(LAST.hasDecomposition());
129+
}
130+
131+
public void testLastReducerIsSelf() {
132+
List<AggregateFunction.IntermediateField> fields = LAST.intermediateFields();
133+
assertEquals(1, fields.size());
134+
assertEquals("last_state", fields.get(0).name());
135+
assertSame(LAST, fields.get(0).reducer());
136+
assertEquals(integer, resolve(fields.get(0), integer));
137+
}
138+
139+
// ── LIST: engine-native (single field, reducer == self, parameterised resolver) ──
140+
141+
public void testListHasDecomposition() {
142+
assertTrue(LIST.hasDecomposition());
143+
}
144+
145+
public void testListReducerIsSelf() {
146+
List<AggregateFunction.IntermediateField> fields = LIST.intermediateFields();
147+
assertEquals(1, fields.size());
148+
assertEquals("list_state", fields.get(0).name());
149+
assertSame(LIST, fields.get(0).reducer());
150+
assertEquals(integer, resolve(fields.get(0), integer));
151+
}
152+
153+
// ── VALUES: engine-native (single field, reducer == self, parameterised resolver) ──
154+
155+
public void testValuesHasDecomposition() {
156+
assertTrue(VALUES.hasDecomposition());
157+
}
158+
159+
public void testValuesReducerIsSelf() {
160+
List<AggregateFunction.IntermediateField> fields = VALUES.intermediateFields();
161+
assertEquals(1, fields.size());
162+
assertEquals("values_state", fields.get(0).name());
163+
assertSame(VALUES, fields.get(0).reducer());
164+
assertEquals(integer, resolve(fields.get(0), integer));
81165
}
82166

83167
// ── fromSqlKind still works ──

sandbox/plugins/analytics-backend-datafusion/rust/src/api.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ pub unsafe fn sql_to_substrait(
534534
.build();
535535
let ctx = datafusion::prelude::SessionContext::new_with_state(state);
536536
crate::udf::register_all(&ctx);
537+
crate::udaf::register_all(&ctx);
537538

538539
let listing_options = ListingOptions::new(Arc::new(ParquetFormat::new()))
539540
.with_file_extension(".parquet")
@@ -585,6 +586,7 @@ fn derive_schema_from_partial_plan(
585586
.build();
586587
let ctx = SessionContext::new_with_state(state);
587588
crate::udf::register_all(&ctx);
589+
crate::udaf::register_all(&ctx);
588590

589591
let extensions = Extensions::default();
590592
let session_state = ctx.state();

sandbox/plugins/analytics-backend-datafusion/rust/src/indexed_executor.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ pub async fn execute_indexed_query(
148148
ctx.register_udf(create_index_filter_udf());
149149
ctx.register_udf(crate::indexed_table::substrait_to_tree::create_delegation_possible_udf());
150150
crate::udf::register_all(&ctx);
151+
crate::udaf::register_all(&ctx);
151152

152153
// Register default ListingTable so substrait consumer can resolve the table
153154
let listing_options = datafusion::datasource::listing::ListingOptions::new(
@@ -450,6 +451,7 @@ pub async unsafe fn execute_indexed_with_context(
450451
let (segments, schema) = build_segments(&state, Arc::clone(&store), object_metas.as_ref(), writer_generations.as_ref())
451452
.await
452453
.map_err(DataFusionError::Execution)?;
454+
let schema = crate::schema_coerce::coerce_inferred_schema(schema);
453455
for (i, seg) in segments.iter().enumerate() {
454456
}
455457

sandbox/plugins/analytics-backend-datafusion/rust/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pub mod runtime_manager;
3333
pub mod schema_coerce;
3434
pub mod session_context;
3535
pub mod statistics_cache;
36+
pub mod udaf;
3637
pub mod udf;
3738
pub mod stats;
3839
pub mod task_monitors;

sandbox/plugins/analytics-backend-datafusion/rust/src/local_executor.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ impl LocalSession {
7474
.build();
7575
let ctx = SessionContext::new_with_state(state);
7676
crate::udf::register_all(&ctx);
77+
crate::udaf::register_all(&ctx);
7778
Self { ctx, prepared_plan: None }
7879
}
7980

sandbox/plugins/analytics-backend-datafusion/rust/src/query_executor.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ pub async fn execute_query(
113113

114114
let ctx = SessionContext::new_with_state(state);
115115
crate::udf::register_all(&ctx);
116+
crate::udaf::register_all(&ctx);
116117

117118
// Register table via ListingTable — all IO goes through object store
118119
let file_format = ParquetFormat::new();

sandbox/plugins/analytics-backend-datafusion/rust/src/session_context.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ pub async unsafe fn create_session_context(
133133
// Without this, fragment execution fails with "Unsupported function name" because
134134
// df_execute_with_context reuses this handle's ctx instead of building a fresh one.
135135
crate::udf::register_all(&ctx);
136+
crate::udaf::register_all(&ctx);
136137

137138
// Register default ListingTable for parquet scans.
138139
let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default()))

0 commit comments

Comments
 (0)