Skip to content

Commit 89de477

Browse files
authored
Add spark value writers for Date/Timestamp/TimestampNTZ and Struct types (#7424)
Add value writers for all the types that we can convert to Spark. Previously we were able to convert the type but failed when writing values. Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent afcdb9a commit 89de477

5 files changed

Lines changed: 127 additions & 16 deletions

File tree

java/testfiles/Cargo.lock

Lines changed: 12 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

java/vortex-jni/src/main/java/dev/vortex/api/expressions/IsNull.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ public static IsNull parse(byte[] metadata, List<Expression> children) {
3535
"IsNull expression must have exactly one child, found: " + children.size());
3636
}
3737
if (metadata.length > 0) {
38-
throw new IllegalArgumentException(
39-
"IsNull expression must not have metadata, found: " + metadata.length);
38+
throw new IllegalArgumentException("IsNull expression must not have metadata, found: " + metadata.length);
4039
}
4140
return new IsNull(children.get(0));
4241
}

java/vortex-spark/src/main/java/dev/vortex/spark/write/SparkToArrowSchema.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ private static ArrowType convertType(DataType sparkType) {
100100
return new ArrowType.Date(DateUnit.DAY);
101101
} else if (sparkType instanceof TimestampType) {
102102
return new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC");
103+
} else if (sparkType instanceof TimestampNTZType) {
104+
return new ArrowType.Timestamp(TimeUnit.MICROSECOND, null);
103105
} else if (sparkType instanceof DecimalType) {
104106
DecimalType decimal = (DecimalType) sparkType;
105107
return new ArrowType.Decimal(decimal.precision(), decimal.scale(), 128);

java/vortex-spark/src/main/java/dev/vortex/spark/write/VortexDataWriter.java

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import dev.vortex.relocated.org.apache.arrow.memory.BufferAllocator;
1111
import dev.vortex.relocated.org.apache.arrow.memory.RootAllocator;
1212
import dev.vortex.relocated.org.apache.arrow.vector.*;
13+
import dev.vortex.relocated.org.apache.arrow.vector.FieldVector;
1314
import dev.vortex.relocated.org.apache.arrow.vector.VectorSchemaRoot;
1415
import dev.vortex.relocated.org.apache.arrow.vector.complex.ListVector;
16+
import dev.vortex.relocated.org.apache.arrow.vector.complex.StructVector;
1517
import dev.vortex.spark.SparkTypes;
1618
import java.io.IOException;
1719
import java.nio.file.Files;
@@ -205,16 +207,23 @@ private void populateVector(
205207
if (bytes != null) {
206208
((VarBinaryVector) vector).setSafe(rowIndex, bytes);
207209
}
208-
} else if (dataType instanceof DecimalType) {
209-
DecimalType decType = (DecimalType) dataType;
210+
} else if (dataType instanceof DateType) {
211+
((DateDayVector) vector).setSafe(rowIndex, row.getInt(fieldIndex));
212+
} else if (dataType instanceof TimestampType) {
213+
((TimeStampMicroTZVector) vector).setSafe(rowIndex, row.getLong(fieldIndex));
214+
} else if (dataType instanceof TimestampNTZType) {
215+
((TimeStampMicroVector) vector).setSafe(rowIndex, row.getLong(fieldIndex));
216+
} else if (dataType instanceof DecimalType decType) {
210217
if (decType.precision() <= 38) {
211218
// Use Decimal type from InternalRow
212219
java.math.BigDecimal decimal = row.getDecimal(fieldIndex, decType.precision(), decType.scale())
213220
.toJavaBigDecimal();
214221
((DecimalVector) vector).setSafe(rowIndex, decimal);
215222
}
216-
} else if (dataType instanceof ArrayType) {
217-
ArrayType arrayType = (ArrayType) dataType;
223+
} else if (dataType instanceof StructType structType) {
224+
populateStructVector(
225+
(StructVector) vector, structType, row.getStruct(fieldIndex, structType.fields().length), rowIndex);
226+
} else if (dataType instanceof ArrayType arrayType) {
218227
ArrayData data = row.getArray(fieldIndex);
219228
ListVector listVector = ((ListVector) vector);
220229
int writtenElements = listVector.getElementEndIndex(listVector.getLastSet());
@@ -229,6 +238,20 @@ private void populateVector(
229238
}
230239
}
231240

241+
private void populateStructVector(StructVector vector, StructType dataType, InternalRow row, int rowIndex) {
242+
vector.setIndexDefined(rowIndex);
243+
244+
StructField[] fields = dataType.fields();
245+
for (int fieldIndex = 0; fieldIndex < fields.length; fieldIndex++) {
246+
FieldVector childVector = (FieldVector) vector.getVectorById(fieldIndex);
247+
if (row.isNullAt(fieldIndex)) {
248+
childVector.setNull(rowIndex);
249+
continue;
250+
}
251+
populateVector(childVector, fields[fieldIndex].dataType(), row, fieldIndex, rowIndex);
252+
}
253+
}
254+
232255
/**
233256
* Commits the write operation and returns a commit message.
234257
* <p>

java/vortex-spark/src/test/java/dev/vortex/spark/VortexDataSourceWriteTest.java

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
package dev.vortex.spark;
55

6+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
67
import static org.junit.jupiter.api.Assertions.assertEquals;
78
import static org.junit.jupiter.api.Assertions.assertTrue;
89

@@ -324,6 +325,73 @@ public void testSpecialCharactersAndNulls() throws IOException {
324325
assertEquals("special!@#$%^&*()", specialRows.first().getString(1));
325326
}
326327

328+
@Test
329+
@DisplayName("Write and read date, timestamp, and nested struct columns")
330+
public void testWriteAndReadTemporalAndStructColumns() throws IOException {
331+
Dataset<Row> originalDf = spark.range(0, 2)
332+
.selectExpr(
333+
"cast(id as int) as id",
334+
"CASE WHEN id = 0 THEN CAST('2024-01-02' AS DATE) ELSE CAST('2024-02-03' AS DATE) END AS event_date",
335+
"CASE WHEN id = 0 THEN CAST('2024-01-02 03:04:05.123456' AS TIMESTAMP) "
336+
+ "ELSE CAST('2024-02-03 04:05:06.654321' AS TIMESTAMP) END AS event_ts",
337+
"named_struct("
338+
+ "'event_date', CASE WHEN id = 0 THEN CAST('2024-01-02' AS DATE) ELSE CAST('2024-02-03' AS DATE) END, "
339+
+ "'event_ts', CASE WHEN id = 0 THEN CAST('2024-01-02 03:04:05.123456' AS TIMESTAMP) "
340+
+ "ELSE CAST('2024-02-03 04:05:06.654321' AS TIMESTAMP) END, "
341+
+ "'label', CASE WHEN id = 0 THEN 'alpha' ELSE 'beta' END"
342+
+ ") AS payload");
343+
344+
Path outputPath = tempDir.resolve("temporal_struct_output");
345+
originalDf
346+
.write()
347+
.format("vortex")
348+
.option("path", outputPath.toUri().toString())
349+
.mode(SaveMode.Overwrite)
350+
.save();
351+
352+
Dataset<Row> readDf = spark.read()
353+
.format("vortex")
354+
.option("path", outputPath.toUri().toString())
355+
.load();
356+
357+
List<String> expectedRows = List.of(
358+
"{\"id\":0,\"event_date\":\"2024-01-02\",\"event_ts\":\"2024-01-02 03:04:05.123456\","
359+
+ "\"payload_event_date\":\"2024-01-02\",\"payload_event_ts\":\"2024-01-02 03:04:05.123456\","
360+
+ "\"payload_label\":\"alpha\"}",
361+
"{\"id\":1,\"event_date\":\"2024-02-03\",\"event_ts\":\"2024-02-03 04:05:06.654321\","
362+
+ "\"payload_event_date\":\"2024-02-03\",\"payload_event_ts\":\"2024-02-03 04:05:06.654321\","
363+
+ "\"payload_label\":\"beta\"}");
364+
365+
assertEquals(DataTypes.DateType, readDf.schema().fields()[1].dataType());
366+
assertEquals(DataTypes.TimestampType, readDf.schema().fields()[2].dataType());
367+
assertTrue(readDf.schema().fields()[3].dataType() instanceof StructType);
368+
assertEquals(expectedRows, projectTemporalAndStructRows(readDf));
369+
}
370+
371+
@Test
372+
@DisplayName("Write TimestampNTZ columns and nested structs")
373+
public void testWriteTimestampNtzColumns() throws IOException {
374+
Dataset<Row> timestampNtzDf = spark.range(0, 2)
375+
.selectExpr(
376+
"cast(id as int) as id",
377+
"CASE WHEN id = 0 THEN CAST('2024-01-02 03:04:05.123456' AS TIMESTAMP_NTZ) "
378+
+ "ELSE CAST(NULL AS TIMESTAMP_NTZ) END AS event_ntz",
379+
"named_struct("
380+
+ "'event_ntz', CASE WHEN id = 0 THEN CAST('2024-01-02 03:04:05.123456' AS TIMESTAMP_NTZ) "
381+
+ "ELSE CAST('2024-02-03 04:05:06.654321' AS TIMESTAMP_NTZ) END"
382+
+ ") AS payload");
383+
384+
Path outputPath = tempDir.resolve("timestamp_ntz_output");
385+
assertDoesNotThrow(() -> timestampNtzDf
386+
.write()
387+
.format("vortex")
388+
.option("path", outputPath.toUri().toString())
389+
.mode(SaveMode.Overwrite)
390+
.save());
391+
392+
assertTrue(!findVortexFiles(outputPath).isEmpty(), "TimestampNTZ write should create Vortex files");
393+
}
394+
327395
/**
328396
* Creates a test DataFrame with monotonically increasing integers
329397
* and their string representations.
@@ -337,6 +405,23 @@ private Dataset<Row> createTestDataFrame(int numRows) {
337405
"array('Alpha', 'Bravo', 'Charlie') AS elements");
338406
}
339407

408+
private List<String> projectTemporalAndStructRows(Dataset<Row> df) {
409+
return df
410+
.orderBy("id")
411+
.selectExpr("to_json(named_struct("
412+
+ "'id', id, "
413+
+ "'event_date', cast(event_date as string), "
414+
+ "'event_ts', date_format(event_ts, 'yyyy-MM-dd HH:mm:ss.SSSSSS'), "
415+
+ "'payload_event_date', cast(payload.event_date as string), "
416+
+ "'payload_event_ts', date_format(payload.event_ts, 'yyyy-MM-dd HH:mm:ss.SSSSSS'), "
417+
+ "'payload_label', payload.label"
418+
+ ")) as json")
419+
.collectAsList()
420+
.stream()
421+
.map(row -> row.getString(0))
422+
.collect(Collectors.toList());
423+
}
424+
340425
/**
341426
* Finds all Vortex files in the given directory.
342427
*/

0 commit comments

Comments (0)