Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions java/vortex-jni/src/main/java/dev/vortex/api/DType.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,15 @@ public interface DType extends AutoCloseable {
List<DType> getFieldTypes();

/**
* Get the element type for a LIST type.
* Get the element type for a LIST or FIXED_SIZE_LIST type.
*/
DType getElementType();

/**
* Get the fixed size for a FIXED_SIZE_LIST type.
*/
int getFixedSizeListSize();

/**
* Checks if this data type represents a date.
*
Expand Down Expand Up @@ -234,6 +239,19 @@ static DType newList(DType element, boolean isNullable) {
return new JNIDType(NativeDTypeMethods.newList(jniType.getPointer(), isNullable), true);
}

/**
* Create a new FixedSizeList data type.
*
* @param element DType of the list elements
* @param size The fixed size of each list
* @param isNullable True if the values can be null
* @return The new DType instance, allocated in native heap memory
*/
static DType newFixedSizeList(DType element, int size, boolean isNullable) {
JNIDType jniType = (JNIDType) element;
return new JNIDType(NativeDTypeMethods.newFixedSizeList(jniType.getPointer(), size, isNullable), true);
}

/**
* Create a new Struct data type.
*
Expand Down Expand Up @@ -467,12 +485,17 @@ enum Variant {
* Decimal type for precise numeric values
*/
DECIMAL,

/**
* Fixed-size list type containing a fixed number of elements of a single type
*/
FIXED_SIZE_LIST,
;

/**
* Converts a byte value to the corresponding Variant enum.
*
* @param variant the byte value representing the variant (0-18)
* @param variant the byte value representing the variant (0-19)
* @return the corresponding {@link Variant} enum value
* @throws RuntimeException if the variant value is not recognized
*/
Expand Down Expand Up @@ -516,6 +539,8 @@ public static Variant from(byte variant) {
return EXTENSION;
case 18:
return DECIMAL;
case 19:
return FIXED_SIZE_LIST;
default:
throw new IllegalArgumentException("Unknown DType variant: " + variant);
}
Expand Down
5 changes: 5 additions & 0 deletions java/vortex-jni/src/main/java/dev/vortex/jni/JNIDType.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public DType getElementType() {
return new JNIDType(NativeDTypeMethods.getElementType(pointer.getAsLong()));
}

@Override
public int getFixedSizeListSize() {
return NativeDTypeMethods.getFixedSizeListSize(pointer.getAsLong());
}

@Override
public boolean isDate() {
return NativeDTypeMethods.isDate(pointer.getAsLong());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ private NativeDTypeMethods() {}
*/
public static native long newList(long elementTypePtr, boolean isNullable);

/**
* Create a new native DType for a FixedSizeList type. The created object lives in native memory.
*
* @param elementTypePtr A native pointer to a DType containing the type of the elements
* @param size The fixed size of each list
* @param isNullable true if the values can be null
* @return Pointer to a new heap-allocated {@code DType}.
*/
public static native long newFixedSizeList(long elementTypePtr, int size, boolean isNullable);

/**
* Create a new native DType for a Struct type. The created object lives in native memory.
*
Expand Down Expand Up @@ -154,6 +164,8 @@ private NativeDTypeMethods() {}

public static native long getElementType(long pointer);

public static native int getFixedSizeListSize(long pointer);

public static native boolean isDate(long pointer);

public static native boolean isTime(long pointer);
Expand Down
75 changes: 75 additions & 0 deletions java/vortex-jni/src/test/java/dev/vortex/api/DTypeTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

package dev.vortex.api;

import static org.junit.jupiter.api.Assertions.*;

import org.junit.jupiter.api.Test;

public final class DTypeTest {

@Test
public void testNewFixedSizeListNonNullable() {
var elementType = DType.newInt(false);
var fslType = DType.newFixedSizeList(elementType, 3, false);
assertEquals(DType.Variant.FIXED_SIZE_LIST, fslType.getVariant());
assertFalse(fslType.isNullable());
assertEquals(3, fslType.getFixedSizeListSize());

var innerType = fslType.getElementType();
assertEquals(DType.Variant.PRIMITIVE_I32, innerType.getVariant());
}

@Test
public void testNewFixedSizeListNullable() {
var elementType = DType.newUtf8(true);
var fslType = DType.newFixedSizeList(elementType, 5, true);
assertEquals(DType.Variant.FIXED_SIZE_LIST, fslType.getVariant());
assertTrue(fslType.isNullable());
assertEquals(5, fslType.getFixedSizeListSize());

var innerType = fslType.getElementType();
assertEquals(DType.Variant.UTF8, innerType.getVariant());
}

@Test
public void testNewListGetElementType() {
var elementType = DType.newDouble(false);
var listType = DType.newList(elementType, false);
assertEquals(DType.Variant.LIST, listType.getVariant());

var innerType = listType.getElementType();
assertEquals(DType.Variant.PRIMITIVE_F64, innerType.getVariant());
}

@Test
public void testNestedFixedSizeList() {
var innerElement = DType.newLong(false);
var innerFsl = DType.newFixedSizeList(innerElement, 2, false);
var outerFsl = DType.newFixedSizeList(innerFsl, 4, true);
assertEquals(DType.Variant.FIXED_SIZE_LIST, outerFsl.getVariant());
assertTrue(outerFsl.isNullable());
assertEquals(4, outerFsl.getFixedSizeListSize());

var inner = outerFsl.getElementType();
assertEquals(DType.Variant.FIXED_SIZE_LIST, inner.getVariant());
}

@Test
public void testFixedSizeListInStruct() {
var elementType = DType.newFloat(false);
var fslType = DType.newFixedSizeList(elementType, 3, false);
var structType = DType.newStruct(
new String[] {"id", "embedding"},
new DType[] {DType.newInt(false), fslType},
false);
assertEquals(DType.Variant.STRUCT, structType.getVariant());

var fieldTypes = structType.getFieldTypes();
assertEquals(2, fieldTypes.size());

var embeddingType = fieldTypes.get(1);
assertEquals(DType.Variant.FIXED_SIZE_LIST, embeddingType.getVariant());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ public static DataType toDataType(DType dType) {

return DataTypes.createStructType(fields);
case LIST:
case FIXED_SIZE_LIST:
return DataTypes.createArrayType(toDataType(dType.getElementType()), dType.isNullable());
case EXTENSION:
/*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

package dev.vortex.spark;

import static org.junit.jupiter.api.Assertions.*;

import dev.vortex.api.DType;
import dev.vortex.jni.NativeLoader;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataTypes;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

public final class SparkTypesTest {

@BeforeAll
public static void loadLibrary() {
NativeLoader.loadJni();
}

@Test
@DisplayName("toDataType should convert FIXED_SIZE_LIST to Spark ArrayType")
public void testFixedSizeListToDataType() {
var elementType = DType.newInt(false);
var fslType = DType.newFixedSizeList(elementType, 3, true);
var sparkType = SparkTypes.toDataType(fslType);
assertInstanceOf(ArrayType.class, sparkType);
ArrayType arrayType = (ArrayType) sparkType;
assertEquals(DataTypes.IntegerType, arrayType.elementType());
}

@Test
@DisplayName("toDataType should convert LIST to Spark ArrayType")
public void testListToDataType() {
var elementType = DType.newDouble(false);
var listType = DType.newList(elementType, true);
var sparkType = SparkTypes.toDataType(listType);
assertInstanceOf(ArrayType.class, sparkType);
ArrayType arrayType = (ArrayType) sparkType;
assertEquals(DataTypes.DoubleType, arrayType.elementType());
}
}
7 changes: 6 additions & 1 deletion vortex-jni/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,12 @@ fn data_type_no_views(data_type: DataType) -> DataType {
}
DataType::Decimal128(precision, scale) => DataType::Decimal128(precision, scale),
DataType::Decimal256(precision, scale) => DataType::Decimal256(precision, scale),
DataType::FixedSizeList(..) => unreachable!("Vortex never returns FixedSizeList"),
DataType::FixedSizeList(inner, size) => {
let new_inner = (*inner)
.clone()
.with_data_type(data_type_no_views(inner.data_type().clone()));
DataType::FixedSizeList(FieldRef::new(new_inner), size)
}
DataType::Union(..) => unreachable!("Vortex never returns Union"),
DataType::Dictionary(..) => unreachable!("Vortex never returns Dictionary"),
DataType::Map(..) => unreachable!("Vortex never returns Map"),
Expand Down
47 changes: 42 additions & 5 deletions vortex-jni/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub const DTYPE_STRUCT: jbyte = 15;
pub const DTYPE_LIST: jbyte = 16;
pub const DTYPE_EXTENSION: jbyte = 17;
pub const DTYPE_DECIMAL: jbyte = 18;
pub const DTYPE_FIXED_SIZE_LIST: jbyte = 19;

static LONG_CLASS: &str = "java/lang/Long";

Expand Down Expand Up @@ -94,9 +95,7 @@ pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_getVariant(
DType::Binary(_) => DTYPE_BINARY,
DType::Struct(..) => DTYPE_STRUCT,
DType::List(..) => DTYPE_LIST,
DType::FixedSizeList(..) => {
unimplemented!("TODO(connor)[FixedSizeList]")
}
DType::FixedSizeList(..) => DTYPE_FIXED_SIZE_LIST,
DType::Extension(_) => DTYPE_EXTENSION,
DType::Variant(_) => unimplemented!("Variant DType is not supported in JNI yet"),
}
Expand Down Expand Up @@ -184,8 +183,11 @@ pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_getElementType(
let dtype = unsafe { &*(dtype_ptr as *const DType) };

try_or_throw(&mut env, |_| {
let Some(element_type) = dtype.as_list_element_opt() else {
throw_runtime!("DType should be LIST, was {dtype}");
let element_type = dtype
.as_list_element_opt()
.or_else(|| dtype.as_fixed_size_list_element_opt());
let Some(element_type) = element_type else {
throw_runtime!("DType should be LIST or FIXED_SIZE_LIST, was {dtype}");
};

Ok(element_type.as_ref() as *const DType as jlong)
Expand Down Expand Up @@ -506,6 +508,41 @@ pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_newList(
Box::into_raw(Box::new(list_type)) as jlong
}

/// FixedSizeList constructor
#[unsafe(no_mangle)]
pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_newFixedSizeList(
_env: JNIEnv,
_class: JClass,
element_ptr: jlong,
size: jint,
is_nullable: jboolean,
) -> jlong {
let element_dtype = unsafe { *Box::from_raw(element_ptr as *mut DType) };
let element_dtype = Arc::new(element_dtype);

let fsl_type = DType::FixedSizeList(element_dtype, size as u32, to_nullability(is_nullable));

Box::into_raw(Box::new(fsl_type)) as jlong
}

/// Get the fixed size of a FixedSizeList DType.
#[unsafe(no_mangle)]
pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_getFixedSizeListSize(
mut env: JNIEnv,
_class: JClass,
dtype_ptr: jlong,
) -> jint {
let dtype = unsafe { &*(dtype_ptr as *const DType) };

try_or_throw(&mut env, |_| {
let DType::FixedSizeList(_, size, _) = dtype else {
throw_runtime!("DType should be FIXED_SIZE_LIST, was {dtype}");
};

Ok(*size as jint)
})
}

/// Struct constructor
#[unsafe(no_mangle)]
pub extern "system" fn Java_dev_vortex_jni_NativeDTypeMethods_newStruct<'local>(
Expand Down
Loading