Skip to content

Commit a615657

Browse files
authored
Allow Python user types to pass through Beam SQL (#38206)
* Allow Python user types to pass through Beam SQL
* Complete the pythonsdk_any logical type representation definition; otherwise, Java-side SchemaTranslation for this logical type would fail
* Handle PassThroughLogicalType in Beam SQL, allowing Beam SQL to treat a PassThroughLogicalType as its base type
* Fix nested bytes in Beam SQL
* Introduce a schema option for compact encoding of static non-null schemas
* Preserve the original logical type in the Beam->Calcite->Beam round trip
1 parent 0cdd1b4 commit a615657

16 files changed

Lines changed: 311 additions & 94 deletions

File tree

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"comment": "Modify this file in a trivial way to cause this test suite to run ",
3-
"modification": 1
3+
"modification": 2
44
}

sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ public abstract class RowCoderGenerator {
113113

114114
private static final String CODERS_FIELD_NAME = "FIELD_CODERS";
115115
private static final String POSITIONS_FIELD_NAME = "FIELD_ENCODING_POSITIONS";
116+
private static final String SCHEMA_OPTION_STATIC_ENCODING = "beam:option:row:static_encoding";
116117

117118
static class WithStackTrace<T> {
118119
private final T value;
@@ -407,8 +408,13 @@ static void encodeDelegate(
407408
checkState(value.getFieldCount() == value.getSchema().getFieldCount());
408409
checkState(encodingPosToIndex.length == value.getFieldCount());
409410

411+
boolean staticEncoding =
412+
value.getSchema().getOptions().getValueOrDefault(SCHEMA_OPTION_STATIC_ENCODING, false);
413+
410414
// Encode the field count. This allows us to handle compatible schema changes.
411-
VAR_INT_CODER.encode(value.getFieldCount(), outputStream);
415+
if (!staticEncoding) {
416+
VAR_INT_CODER.encode(value.getFieldCount(), outputStream);
417+
}
412418

413419
if (hasNullableFields) {
414420
// If the row has null fields, extract the values out once so that both scanNullFields and
@@ -420,7 +426,9 @@ static void encodeDelegate(
420426
}
421427

422428
// Encode a bitmap for the null fields to save having to encode a bunch of nulls.
423-
NULL_LIST_CODER.encode(scanNullFields(fieldValues, encodingPosToIndex), outputStream);
429+
if (!staticEncoding) {
430+
NULL_LIST_CODER.encode(scanNullFields(fieldValues, encodingPosToIndex), outputStream);
431+
}
424432
for (int encodingPos = 0; encodingPos < fieldValues.length; ++encodingPos) {
425433
@Nullable Object fieldValue = fieldValues[encodingPosToIndex[encodingPos]];
426434
if (fieldValue != null) {
@@ -430,7 +438,9 @@ static void encodeDelegate(
430438
} else {
431439
// Otherwise, we know all fields are non-null, so the null list is always empty.
432440

433-
NULL_LIST_CODER.encode(EMPTY_BIT_SET, outputStream);
441+
if (!staticEncoding) {
442+
NULL_LIST_CODER.encode(EMPTY_BIT_SET, outputStream);
443+
}
434444
for (int encodingPos = 0; encodingPos < value.getFieldCount(); ++encodingPos) {
435445
@Nullable Object fieldValue = value.getValue(encodingPosToIndex[encodingPos]);
436446
if (fieldValue != null) {
@@ -511,9 +521,15 @@ public InstrumentedType prepare(InstrumentedType instrumentedType) {
511521
static Row decodeDelegate(
512522
Schema schema, Coder[] coders, int[] encodingPosToIndex, InputStream inputStream)
513523
throws IOException {
514-
int fieldCount = VAR_INT_CODER.decode(inputStream);
515-
516-
BitSet nullFields = NULL_LIST_CODER.decode(inputStream);
524+
int fieldCount;
525+
BitSet nullFields;
526+
if (schema.getOptions().getValueOrDefault(SCHEMA_OPTION_STATIC_ENCODING, false)) {
527+
fieldCount = schema.getFieldCount();
528+
nullFields = new BitSet();
529+
} else {
530+
fieldCount = VAR_INT_CODER.decode(inputStream);
531+
nullFields = NULL_LIST_CODER.decode(inputStream);
532+
}
517533
Object[] fieldValues = new Object[coders.length];
518534
for (int encodingPos = 0; encodingPos < fieldCount; ++encodingPos) {
519535
// In the case of a schema change going backwards, fieldCount might be > coders.length,

sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,4 +626,23 @@ public void testEncodingPositionRemoveFields() throws Exception {
626626
Row decoded = RowCoder.of(schema2).decode(new ByteArrayInputStream(os.toByteArray()));
627627
assertEquals(expected, decoded);
628628
}
629+
630+
@Test
631+
public void testStaticEncoding() throws Exception {
632+
Schema schema =
633+
Schema.builder()
634+
.addInt32Field("f_int32")
635+
.addStringField("f_string")
636+
.setOptions(
637+
Schema.Options.builder()
638+
.setOption("beam:option:row:static_encoding", FieldType.BOOLEAN, true)
639+
.build())
640+
.build();
641+
Row row = Row.withSchema(schema).addValues(42, "hello world!").build();
642+
ByteArrayOutputStream bos = new ByteArrayOutputStream();
643+
RowCoder.of(schema).encode(row, bos);
644+
assertEquals(14, bos.toByteArray().length);
645+
646+
CoderProperties.coderDecodeEncodeEqual(RowCoder.of(schema), row);
647+
}
629648
}

sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java

Lines changed: 69 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,10 @@ public PCollection<Row> expand(PCollectionList<Row> pinput) {
221221
BeamSqlPipelineOptions options =
222222
pinput.getPipeline().getOptions().as(BeamSqlPipelineOptions.class);
223223

224+
String builderString = builder.toBlock().toString();
224225
CalcFn calcFn =
225226
new CalcFn(
226-
builder.toBlock().toString(),
227+
builderString,
227228
outputSchema,
228229
options.getVerifyRowValues(),
229230
getJarPaths(program),
@@ -502,120 +503,109 @@ FieldAccessDescriptor getFieldAccess() {
502503
@Override
503504
public Expression field(BlockBuilder list, int index, Type storageType) {
504505
this.referencedColumns.add(index);
505-
return getBeamField(list, index, input, inputSchema);
506+
return getBeamField(list, index, input, inputSchema, true);
506507
}
507508

508509
// Read field from Beam Row
509510
private static Expression getBeamField(
510-
BlockBuilder list, int index, Expression input, Schema schema) {
511+
BlockBuilder list, int index, Expression input, Schema schema, boolean useByteString) {
511512
if (index >= schema.getFieldCount() || index < 0) {
512513
throw new IllegalArgumentException("Unable to find value #" + index);
513514
}
514515

515516
final Expression expression = list.append(list.newName("current"), input);
516-
517517
final Field field = schema.getField(index);
518518
final FieldType fieldType = field.getType();
519519
final Expression fieldName = Expressions.constant(field.getName());
520+
Expression value = getBeamField(list, expression, fieldName, fieldType);
521+
522+
return toCalciteValue(value, fieldType, useByteString);
523+
}
524+
525+
private static Expression getBeamField(
526+
BlockBuilder list, Expression expression, Expression fieldName, FieldType fieldType) {
520527
final Expression value;
521528
switch (fieldType.getTypeName()) {
522529
case BYTE:
523-
value = Expressions.call(expression, "getByte", fieldName);
524-
break;
530+
return Expressions.call(expression, "getByte", fieldName);
525531
case INT16:
526-
value = Expressions.call(expression, "getInt16", fieldName);
527-
break;
532+
return Expressions.call(expression, "getInt16", fieldName);
528533
case INT32:
529-
value = Expressions.call(expression, "getInt32", fieldName);
530-
break;
534+
return Expressions.call(expression, "getInt32", fieldName);
531535
case INT64:
532-
value = Expressions.call(expression, "getInt64", fieldName);
533-
break;
536+
return Expressions.call(expression, "getInt64", fieldName);
534537
case DECIMAL:
535-
value = Expressions.call(expression, "getDecimal", fieldName);
536-
break;
538+
return Expressions.call(expression, "getDecimal", fieldName);
537539
case FLOAT:
538-
value = Expressions.call(expression, "getFloat", fieldName);
539-
break;
540+
return Expressions.call(expression, "getFloat", fieldName);
540541
case DOUBLE:
541-
value = Expressions.call(expression, "getDouble", fieldName);
542-
break;
542+
return Expressions.call(expression, "getDouble", fieldName);
543543
case STRING:
544-
value = Expressions.call(expression, "getString", fieldName);
545-
break;
544+
return Expressions.call(expression, "getString", fieldName);
546545
case DATETIME:
547-
value = Expressions.call(expression, "getDateTime", fieldName);
548-
break;
546+
return Expressions.call(expression, "getDateTime", fieldName);
549547
case BOOLEAN:
550-
value = Expressions.call(expression, "getBoolean", fieldName);
551-
break;
548+
return Expressions.call(expression, "getBoolean", fieldName);
552549
case BYTES:
553-
value = Expressions.call(expression, "getBytes", fieldName);
554-
break;
550+
return Expressions.call(expression, "getBytes", fieldName);
555551
case ARRAY:
556-
value = Expressions.call(expression, "getArray", fieldName);
557-
break;
552+
return Expressions.call(expression, "getArray", fieldName);
558553
case MAP:
559-
value = Expressions.call(expression, "getMap", fieldName);
560-
break;
554+
return Expressions.call(expression, "getMap", fieldName);
561555
case ROW:
562-
value = Expressions.call(expression, "getRow", fieldName);
563-
break;
556+
return Expressions.call(expression, "getRow", fieldName);
564557
case ITERABLE:
565-
value = Expressions.call(expression, "getIterable", fieldName);
566-
break;
558+
return Expressions.call(expression, "getIterable", fieldName);
567559
case LOGICAL_TYPE:
568-
String identifier = fieldType.getLogicalType().getIdentifier();
560+
LogicalType logicalType = fieldType.getLogicalType();
561+
String identifier = logicalType.getIdentifier();
569562
if (FixedString.IDENTIFIER.equals(identifier)
570563
|| VariableString.IDENTIFIER.equals(identifier)) {
571-
value = Expressions.call(expression, "getString", fieldName);
564+
return Expressions.call(expression, "getString", fieldName);
572565
} else if (FixedBytes.IDENTIFIER.equals(identifier)
573566
|| VariableBytes.IDENTIFIER.equals(identifier)) {
574-
value = Expressions.call(expression, "getBytes", fieldName);
567+
return Expressions.call(expression, "getBytes", fieldName);
575568
} else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) {
576-
value = Expressions.call(expression, "getDateTime", fieldName);
569+
return Expressions.call(expression, "getDateTime", fieldName);
577570
} else if (SqlTypes.DATE.getIdentifier().equals(identifier)) {
578-
value =
579-
Expressions.convert_(
580-
Expressions.call(
581-
expression,
582-
"getLogicalTypeValue",
583-
fieldName,
584-
Expressions.constant(LocalDate.class)),
585-
LocalDate.class);
571+
return Expressions.convert_(
572+
Expressions.call(
573+
expression,
574+
"getLogicalTypeValue",
575+
fieldName,
576+
Expressions.constant(LocalDate.class)),
577+
LocalDate.class);
586578
} else if (SqlTypes.TIME.getIdentifier().equals(identifier)) {
587-
value =
588-
Expressions.convert_(
589-
Expressions.call(
590-
expression,
591-
"getLogicalTypeValue",
592-
fieldName,
593-
Expressions.constant(LocalTime.class)),
594-
LocalTime.class);
579+
return Expressions.convert_(
580+
Expressions.call(
581+
expression,
582+
"getLogicalTypeValue",
583+
fieldName,
584+
Expressions.constant(LocalTime.class)),
585+
LocalTime.class);
595586
} else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) {
596-
value =
597-
Expressions.convert_(
598-
Expressions.call(
599-
expression,
600-
"getLogicalTypeValue",
601-
fieldName,
602-
Expressions.constant(LocalDateTime.class)),
603-
LocalDateTime.class);
587+
return Expressions.convert_(
588+
Expressions.call(
589+
expression,
590+
"getLogicalTypeValue",
591+
fieldName,
592+
Expressions.constant(LocalDateTime.class)),
593+
LocalDateTime.class);
604594
} else if (FixedPrecisionNumeric.IDENTIFIER.equals(identifier)) {
605-
value = Expressions.call(expression, "getDecimal", fieldName);
595+
return Expressions.call(expression, "getDecimal", fieldName);
596+
} else if (logicalType instanceof PassThroughLogicalType) {
597+
return getBeamField(list, expression, fieldName, logicalType.getBaseType());
606598
} else {
607599
throw new UnsupportedOperationException("Unable to get logical type " + identifier);
608600
}
609-
break;
610601
default:
611602
throw new UnsupportedOperationException("Unable to get " + fieldType.getTypeName());
612603
}
613-
614-
return toCalciteValue(value, fieldType);
615604
}
616605

617606
// Value conversion: Beam => Calcite
618-
private static Expression toCalciteValue(Expression value, FieldType fieldType) {
607+
private static Expression toCalciteValue(
608+
Expression value, FieldType fieldType, boolean useByteString) {
619609
switch (fieldType.getTypeName()) {
620610
case BYTE:
621611
return Expressions.convert_(value, Byte.class);
@@ -642,7 +632,10 @@ private static Expression toCalciteValue(Expression value, FieldType fieldType)
642632
Expressions.call(Expressions.convert_(value, AbstractInstant.class), "getMillis"));
643633
case BYTES:
644634
return nullOr(
645-
value, Expressions.new_(ByteString.class, Expressions.convert_(value, byte[].class)));
635+
value,
636+
useByteString
637+
? Expressions.new_(ByteString.class, Expressions.convert_(value, byte[].class))
638+
: Expressions.convert_(value, byte[].class));
646639
case ARRAY:
647640
case ITERABLE:
648641
return nullOr(value, toCalciteList(value, fieldType.getCollectionElementType()));
@@ -651,7 +644,8 @@ private static Expression toCalciteValue(Expression value, FieldType fieldType)
651644
case ROW:
652645
return nullOr(value, toCalciteRow(value, fieldType.getRowSchema()));
653646
case LOGICAL_TYPE:
654-
String identifier = fieldType.getLogicalType().getIdentifier();
647+
LogicalType logicalType = fieldType.getLogicalType();
648+
String identifier = logicalType.getIdentifier();
655649
if (FixedString.IDENTIFIER.equals(identifier)
656650
|| VariableString.IDENTIFIER.equals(identifier)) {
657651
return Expressions.convert_(value, String.class);
@@ -692,6 +686,8 @@ private static Expression toCalciteValue(Expression value, FieldType fieldType)
692686
return nullOr(value, returnValue);
693687
} else if (FixedPrecisionNumeric.IDENTIFIER.equals(identifier)) {
694688
return Expressions.convert_(value, BigDecimal.class);
689+
} else if (logicalType instanceof PassThroughLogicalType) {
690+
return toCalciteValue(value, logicalType.getBaseType(), useByteString);
695691
} else {
696692
throw new UnsupportedOperationException("Unable to convert logical type " + identifier);
697693
}
@@ -704,7 +700,7 @@ private static Expression toCalciteList(Expression input, FieldType elementType)
704700
ParameterExpression value = Expressions.parameter(Object.class);
705701

706702
BlockBuilder block = new BlockBuilder();
707-
block.add(toCalciteValue(value, elementType));
703+
block.add(toCalciteValue(value, elementType, false));
708704

709705
return Expressions.new_(
710706
WrappedList.class,
@@ -722,7 +718,7 @@ private static Expression toCalciteMap(Expression input, FieldType mapValueType)
722718
ParameterExpression value = Expressions.parameter(Object.class);
723719

724720
BlockBuilder block = new BlockBuilder();
725-
block.add(toCalciteValue(value, mapValueType));
721+
block.add(toCalciteValue(value, mapValueType, false));
726722

727723
return Expressions.new_(
728724
WrappedMap.class,
@@ -745,7 +741,8 @@ private static Expression toCalciteRow(Expression input, Schema schema) {
745741

746742
for (int i = 0; i < schema.getFieldCount(); i++) {
747743
BlockBuilder list = new BlockBuilder(/* optimizing= */ false, body);
748-
Expression returnValue = getBeamField(list, i, row, schema);
744+
// instruct conversion of BYTES to byte[], required by BeamJavaTypeFactory
745+
Expression returnValue = getBeamField(list, i, row, schema, false);
749746

750747
list.append(returnValue);
751748

0 commit comments

Comments
 (0)