Skip to content

Commit faf8a1a

Browse files
committed
Preserve original logical type in Beam->Calcite->Beam trip
1 parent a8fa6db commit faf8a1a

6 files changed

Lines changed: 45 additions & 26 deletions

File tree

sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.lang.reflect.Type;
2323
import java.util.Date;
2424
import java.util.Map;
25+
import java.util.WeakHashMap;
2526
import java.util.stream.IntStream;
2627
import org.apache.beam.sdk.schemas.Schema;
2728
import org.apache.beam.sdk.schemas.Schema.FieldType;
@@ -169,6 +170,11 @@ public static boolean isStringType(FieldType fieldType) {
169170
FieldType.DATETIME, SqlTypeName.TIMESTAMP,
170171
FieldType.STRING, SqlTypeName.VARCHAR);
171172

173+
// Use a weak hash map to preserve the original logical type in the output schema across a full
174+
// Beam FieldType -> Calcite Type -> Beam FieldType round trip.
175+
private static final Map<RelDataType, FieldType> LOGICAL_TYPE_REL_DATA_MAPPING =
176+
new WeakHashMap<>();
177+
172178
/** Generate {@link Schema} from {@code RelDataType} which is used to create table. */
173179
public static Schema toSchema(RelDataType tableInfo) {
174180
return tableInfo.getFieldList().stream().map(CalciteUtils::toField).collect(Schema.toSchema());
@@ -254,6 +260,9 @@ public static Schema.Field toField(String name, RelDataType calciteType) {
254260
}
255261

256262
public static FieldType toFieldType(RelDataType calciteType) {
263+
if (LOGICAL_TYPE_REL_DATA_MAPPING.containsKey(calciteType)) {
264+
return LOGICAL_TYPE_REL_DATA_MAPPING.get(calciteType);
265+
}
257266
switch (calciteType.getSqlTypeName()) {
258267
case ARRAY:
259268
case MULTISET:
@@ -317,10 +326,27 @@ public static RelDataType toRelDataType(RelDataTypeFactory dataTypeFactory, Fiel
317326
return toCalciteRowType(schema, dataTypeFactory);
318327
case LOGICAL_TYPE:
319328
Schema.LogicalType<?, ?> logicalType = fieldType.getLogicalType();
329+
RelDataType relDataType;
320330
if (logicalType instanceof PassThroughLogicalType) {
321-
return toRelDataType(dataTypeFactory, logicalType.getBaseType());
331+
relDataType =
332+
toRelDataType(
333+
dataTypeFactory, logicalType.getBaseType().withNullable(fieldType.getNullable()));
334+
} else {
335+
relDataType = dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
322336
}
323-
return dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
337+
// For backward-compatibility, exclude logical types registered in
338+
// CALCITE_TO_BEAM_TYPE_MAPPING,
339+
// e.g., primitive types, date time types, etc.
340+
SqlTypeName typeName = relDataType.getSqlTypeName();
341+
if (typeName != null && !CALCITE_TO_BEAM_TYPE_MAPPING.containsKey(typeName)) {
342+
// register both nullable and non-nullable variants
343+
boolean flipNullable = !relDataType.isNullable();
344+
LOGICAL_TYPE_REL_DATA_MAPPING.put(relDataType, fieldType);
345+
LOGICAL_TYPE_REL_DATA_MAPPING.put(
346+
dataTypeFactory.createTypeWithNullability(relDataType, flipNullable),
347+
fieldType.withNullable(flipNullable));
348+
}
349+
return relDataType;
324350
default:
325351
return dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
326352
}

sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
*/
1818
package org.apache.beam.sdk.extensions.sql;
1919

20+
import static org.junit.Assert.assertEquals;
21+
2022
import java.nio.charset.StandardCharsets;
2123
import java.time.LocalDate;
2224
import java.time.LocalDateTime;
@@ -792,6 +794,7 @@ public void testUnknownLogicalType() {
792794
.apply(SqlTransform.query("select * from PCOLLECTION"));
793795

794796
PAssert.that(outputRow).containsInAnyOrder(inputRow);
797+
assertEquals(inputRow.getSchema(), outputRow.getSchema());
795798
pipeline.run().waitUntilFinish(Duration.standardMinutes(1));
796799
}
797800
}

sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import java.util.Map;
2525
import java.util.stream.Collectors;
2626
import org.apache.beam.sdk.schemas.Schema;
27+
import org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType;
28+
import org.apache.beam.sdk.values.Row;
2729
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataType;
2830
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataTypeFactory;
2931
import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataTypeSystem;
@@ -184,9 +186,8 @@ public void testToRelDataTypeWithRowBackedLogicalType() {
184186
Schema nestedSchema = Schema.builder().addField("nested_f1", Schema.FieldType.INT32).build();
185187
Schema.FieldType rowType = Schema.FieldType.row(nestedSchema);
186188

187-
Schema.LogicalType<org.apache.beam.sdk.values.Row, org.apache.beam.sdk.values.Row> logicalType =
188-
new org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType<
189-
org.apache.beam.sdk.values.Row>(
189+
Schema.LogicalType<Row, Row> logicalType =
190+
new PassThroughLogicalType<Row>(
190191
"RowBackedLogicalType", Schema.FieldType.STRING, "", rowType) {};
191192

192193
Schema.FieldType logicalFieldType = Schema.FieldType.logicalType(logicalType);

sdks/python/apache_beam/coders/coder_impl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def encode_to_stream(self, value, out, nested):
841841
out.write_bigendian_int16(value)
842842

843843
def decode_from_stream(self, in_stream, nested):
844-
# type: (create_InputStream, bool) -> float
844+
# type: (create_InputStream, bool) -> int
845845
return in_stream.read_bigendian_int16()
846846

847847
def estimate_size(self, unused_value, nested=False):
@@ -857,12 +857,12 @@ def encode_to_stream(self, value, out, nested):
857857
out.write_byte(value)
858858

859859
def decode_from_stream(self, in_stream, nested):
860-
# type: (create_InputStream, bool) -> float
860+
# type: (create_InputStream, bool) -> int
861861
return in_stream.read_byte()
862862

863863
def estimate_size(self, unused_value, nested=False):
864864
# type: (Any, bool) -> int
865-
# A short is encoded as 2 bytes, regardless of nesting.
865+
# A byte is encoded as 1 byte, regardless of nesting.
866866
return 1
867867

868868

sdks/python/apache_beam/transforms/sql_test.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -162,19 +162,6 @@ def test_row(self):
162162
| SqlTransform("SELECT a*a as s, LENGTH(b) AS c FROM PCOLLECTION"))
163163
assert_that(out, equal_to([(1, 1), (4, 1), (100, 2)]))
164164

165-
@staticmethod
166-
def recover_to_python_type(input):
167-
fields = []
168-
for field in input:
169-
print(field)
170-
if hasattr(field, 'type_byte') and hasattr(field, 'payload'):
171-
obj = coders.FastPrimitivesCoder().decode(
172-
field.type_byte.to_bytes() + field.payload)
173-
fields.append(obj)
174-
else:
175-
fields.append(field)
176-
return tuple(fields)
177-
178165
def test_row_user_type(self):
179166
with TestPipeline() as p:
180167
out = (
@@ -183,9 +170,7 @@ def test_row_user_type(self):
183170
UserTypeRow(1, Aribitrary("abc"), -1j),
184171
])
185172
| SqlTransform("SELECT arb, complex FROM PCOLLECTION")
186-
# TODO: recover to user type. Currently pipeline can run,
187-
# but elements returned back to Python are generated rows
188-
| beam.Map(self.recover_to_python_type))
173+
| beam.Map(tuple))
189174
assert_that(
190175
out,
191176
equal_to([(Aribitrary(1.0), 1 + 2.5j), (Aribitrary("abc"), -1j)]))

sdks/python/apache_beam/typehints/schemas.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,11 @@
106106
from apache_beam.utils.timestamp import Timestamp
107107

108108
PYTHON_ANY_URN = "beam:logical:pythonsdk_any:v1"
109+
_PYTHON_ANY_FIELD_TYPE_BYTE = "_pythonsdk_any_type_byte"
110+
_PYTHON_ANY_FIELD_PAYLOAD = "payload"
109111
_SCHEMA_OPTION_STATIC_ENCODING = "beam:option:row:static_encoding"
110112

113+
111114
# Bi-directional mappings
112115
_PRIMITIVES = (
113116
(np.int8, schema_pb2.BYTE),
@@ -257,6 +260,7 @@ def schema_field(
257260

258261

259262
def _python_any_schema_pb2():
263+
# A portable schema whose fields match FastPrimitivesCoder-encoded values.
260264
return schema_pb2.FieldType(
261265
logical_type=schema_pb2.LogicalType(
262266
urn=PYTHON_ANY_URN,
@@ -266,11 +270,11 @@ def _python_any_schema_pb2():
266270
schema=schema_pb2.Schema(
267271
fields=[
268272
schema_pb2.Field(
269-
name="type_byte",
273+
name=_PYTHON_ANY_FIELD_TYPE_BYTE,
270274
type=schema_pb2.FieldType(
271275
atomic_type=schema_pb2.BYTE, nullable=False)),
272276
schema_pb2.Field(
273-
name="payload",
277+
name=_PYTHON_ANY_FIELD_PAYLOAD,
274278
type=schema_pb2.FieldType(
275279
atomic_type=schema_pb2.BYTES, nullable=False))
276280
],

0 commit comments

Comments (0)