Skip to content

Commit 5bb8447

Browse files
committed
Preserve original logical type in Beam->Calcite->Beam trip
1 parent a8fa6db commit 5bb8447

6 files changed

Lines changed: 45 additions & 26 deletions

File tree

sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.lang.reflect.Type;
2323
import java.util.Date;
2424
import java.util.Map;
25+
import java.util.concurrent.ConcurrentHashMap;
2526
import java.util.stream.IntStream;
2627
import org.apache.beam.sdk.schemas.Schema;
2728
import org.apache.beam.sdk.schemas.Schema.FieldType;
@@ -169,6 +170,12 @@ public static boolean isStringType(FieldType fieldType) {
169170
FieldType.DATETIME, SqlTypeName.TIMESTAMP,
170171
FieldType.STRING, SqlTypeName.VARCHAR);
171172

173+
// Associating FieldType to generated RelDataType objects for Beam logical types. Used for
174+
// recovering the original type in output schema after full Beam FieldType->Calcite Type->Beam
175+
// FieldType trip
176+
private static final Map<RelDataType, FieldType> LOGICAL_TYPE_REL_DATA_MAPPING =
177+
new ConcurrentHashMap<>();
178+
172179
/** Generate {@link Schema} from {@code RelDataType} which is used to create table. */
173180
public static Schema toSchema(RelDataType tableInfo) {
174181
return tableInfo.getFieldList().stream().map(CalciteUtils::toField).collect(Schema.toSchema());
@@ -254,6 +261,9 @@ public static Schema.Field toField(String name, RelDataType calciteType) {
254261
}
255262

256263
public static FieldType toFieldType(RelDataType calciteType) {
264+
if (LOGICAL_TYPE_REL_DATA_MAPPING.containsKey(calciteType)) {
265+
return LOGICAL_TYPE_REL_DATA_MAPPING.get(calciteType);
266+
}
257267
switch (calciteType.getSqlTypeName()) {
258268
case ARRAY:
259269
case MULTISET:
@@ -317,10 +327,27 @@ public static RelDataType toRelDataType(RelDataTypeFactory dataTypeFactory, Fiel
317327
return toCalciteRowType(schema, dataTypeFactory);
318328
case LOGICAL_TYPE:
319329
Schema.LogicalType<?, ?> logicalType = fieldType.getLogicalType();
330+
RelDataType relDataType;
320331
if (logicalType instanceof PassThroughLogicalType) {
321-
return toRelDataType(dataTypeFactory, logicalType.getBaseType());
332+
relDataType =
333+
toRelDataType(
334+
dataTypeFactory, logicalType.getBaseType().withNullable(fieldType.getNullable()));
335+
} else {
336+
relDataType = dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
322337
}
323-
return dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
338+
// For backward-compatibility, exclude logical types registered in
339+
// CALCITE_TO_BEAM_TYPE_MAPPING,
340+
// e.g., primitive types, date time types, etc.
341+
SqlTypeName typeName = relDataType.getSqlTypeName();
342+
if (typeName != null && !CALCITE_TO_BEAM_TYPE_MAPPING.containsKey(typeName)) {
343+
// register both nullable and non-nullable variants.
344+
boolean flipNullable = !relDataType.isNullable();
345+
LOGICAL_TYPE_REL_DATA_MAPPING.put(relDataType, fieldType);
346+
LOGICAL_TYPE_REL_DATA_MAPPING.put(
347+
dataTypeFactory.createTypeWithNullability(relDataType, flipNullable),
348+
fieldType.withNullable(flipNullable));
349+
}
350+
return relDataType;
324351
default:
325352
return dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
326353
}

sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
*/
1818
package org.apache.beam.sdk.extensions.sql;
1919

20+
import static org.junit.Assert.assertEquals;
21+
2022
import java.nio.charset.StandardCharsets;
2123
import java.time.LocalDate;
2224
import java.time.LocalDateTime;
@@ -792,6 +794,7 @@ public void testUnknownLogicalType() {
792794
.apply(SqlTransform.query("select * from PCOLLECTION"));
793795

794796
PAssert.that(outputRow).containsInAnyOrder(inputRow);
797+
assertEquals(inputRow.getSchema(), outputRow.getSchema());
795798
pipeline.run().waitUntilFinish(Duration.standardMinutes(1));
796799
}
797800
}

sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import java.util.Map;
2525
import java.util.stream.Collectors;
2626
import org.apache.beam.sdk.schemas.Schema;
27+
import org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType;
28+
import org.apache.beam.sdk.values.Row;
2729
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataType;
2830
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataTypeFactory;
2931
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataTypeSystem;
@@ -184,9 +186,8 @@ public void testToRelDataTypeWithRowBackedLogicalType() {
184186
Schema nestedSchema = Schema.builder().addField("nested_f1", Schema.FieldType.INT32).build();
185187
Schema.FieldType rowType = Schema.FieldType.row(nestedSchema);
186188

187-
Schema.LogicalType<org.apache.beam.sdk.values.Row, org.apache.beam.sdk.values.Row> logicalType =
188-
new org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType<
189-
org.apache.beam.sdk.values.Row>(
189+
Schema.LogicalType<Row, Row> logicalType =
190+
new PassThroughLogicalType<Row>(
190191
"RowBackedLogicalType", Schema.FieldType.STRING, "", rowType) {};
191192

192193
Schema.FieldType logicalFieldType = Schema.FieldType.logicalType(logicalType);

sdks/python/apache_beam/coders/coder_impl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def encode_to_stream(self, value, out, nested):
841841
out.write_bigendian_int16(value)
842842

843843
def decode_from_stream(self, in_stream, nested):
844-
# type: (create_InputStream, bool) -> float
844+
# type: (create_InputStream, bool) -> int
845845
return in_stream.read_bigendian_int16()
846846

847847
def estimate_size(self, unused_value, nested=False):
@@ -857,12 +857,12 @@ def encode_to_stream(self, value, out, nested):
857857
out.write_byte(value)
858858

859859
def decode_from_stream(self, in_stream, nested):
860-
# type: (create_InputStream, bool) -> float
860+
# type: (create_InputStream, bool) -> int
861861
return in_stream.read_byte()
862862

863863
def estimate_size(self, unused_value, nested=False):
864864
# type: (Any, bool) -> int
865-
# A short is encoded as 2 bytes, regardless of nesting.
865+
# A byte is encoded as 1 byte, regardless of nesting.
866866
return 1
867867

868868

sdks/python/apache_beam/transforms/sql_test.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -162,19 +162,6 @@ def test_row(self):
162162
| SqlTransform("SELECT a*a as s, LENGTH(b) AS c FROM PCOLLECTION"))
163163
assert_that(out, equal_to([(1, 1), (4, 1), (100, 2)]))
164164

165-
@staticmethod
166-
def recover_to_python_type(input):
167-
fields = []
168-
for field in input:
169-
print(field)
170-
if hasattr(field, 'type_byte') and hasattr(field, 'payload'):
171-
obj = coders.FastPrimitivesCoder().decode(
172-
field.type_byte.to_bytes() + field.payload)
173-
fields.append(obj)
174-
else:
175-
fields.append(field)
176-
return tuple(fields)
177-
178165
def test_row_user_type(self):
179166
with TestPipeline() as p:
180167
out = (
@@ -183,9 +170,7 @@ def test_row_user_type(self):
183170
UserTypeRow(1, Aribitrary("abc"), -1j),
184171
])
185172
| SqlTransform("SELECT arb, complex FROM PCOLLECTION")
186-
# TODO: recover to user type. Currently pipeline can run,
187-
# but elements returned back to Python are generated rows
188-
| beam.Map(self.recover_to_python_type))
173+
| beam.Map(tuple))
189174
assert_that(
190175
out,
191176
equal_to([(Aribitrary(1.0), 1 + 2.5j), (Aribitrary("abc"), -1j)]))

sdks/python/apache_beam/typehints/schemas.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@
106106
from apache_beam.utils.timestamp import Timestamp
107107

108108
PYTHON_ANY_URN = "beam:logical:pythonsdk_any:v1"
109+
_PYTHON_ANY_FIELD_TYPE_BYTE = "_pythonsdk_any_type_byte"
110+
_PYTHON_ANY_FIELD_PAYLOAD = "payload"
109111
_SCHEMA_OPTION_STATIC_ENCODING = "beam:option:row:static_encoding"
110112

111113
# Bi-directional mappings
@@ -257,6 +259,7 @@ def schema_field(
257259

258260

259261
def _python_any_schema_pb2():
262+
# A portable schema matches FastPrimitivesCoder encoded values
260263
return schema_pb2.FieldType(
261264
logical_type=schema_pb2.LogicalType(
262265
urn=PYTHON_ANY_URN,
@@ -266,11 +269,11 @@ def _python_any_schema_pb2():
266269
schema=schema_pb2.Schema(
267270
fields=[
268271
schema_pb2.Field(
269-
name="type_byte",
272+
name=_PYTHON_ANY_FIELD_TYPE_BYTE,
270273
type=schema_pb2.FieldType(
271274
atomic_type=schema_pb2.BYTE, nullable=False)),
272275
schema_pb2.Field(
273-
name="payload",
276+
name=_PYTHON_ANY_FIELD_PAYLOAD,
274277
type=schema_pb2.FieldType(
275278
atomic_type=schema_pb2.BYTES, nullable=False))
276279
],

0 commit comments

Comments (0)