Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@ object PaimonTableValuedFunctions {
VECTOR_SEARCH,
FULL_TEXT_SEARCH)

def parsePositiveLimit(value: Any): Int = {
val limit = value match {
case i: Int => i
case l: Long if l <= Int.MaxValue => l.toInt
case l: Long =>
throw new IllegalArgumentException(
s"Limit must be no greater than ${Int.MaxValue}, but got: $l")
case other => throw new RuntimeException(s"Invalid limit type: ${other.getClass.getName}")
}
if (limit <= 0) {
throw new IllegalArgumentException(
s"Limit must be a positive integer, but got: $limit"
)
}
limit
}

private type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder)

def getTableValueFunctionInjection(fnName: String): TableFunctionDescription = {
Expand Down Expand Up @@ -307,16 +324,7 @@ case class VectorSearchQuery(override val args: Seq[Expression])
)
}
val queryVector = extractQueryVector(argsWithoutTable(1))
val limit = argsWithoutTable(2).eval() match {
case i: Int => i
case l: Long => l.toInt
case other => throw new RuntimeException(s"Invalid limit type: ${other.getClass.getName}")
}
if (limit <= 0) {
throw new IllegalArgumentException(
s"Limit must be a positive integer, but got: $limit"
)
}
val limit = parsePositiveLimit(argsWithoutTable(2).eval())
new VectorSearch(queryVector, limit, columnName)
}

Expand Down Expand Up @@ -374,16 +382,7 @@ case class FullTextSearchQuery(override val args: Seq[Expression])
)
}
val queryText = argsWithoutTable(1).eval().toString
val limit = argsWithoutTable(2).eval() match {
case i: Int => i
case l: Long => l.toInt
case other => throw new RuntimeException(s"Invalid limit type: ${other.getClass.getName}")
}
if (limit <= 0) {
throw new IllegalArgumentException(
s"Limit must be a positive integer, but got: $limit"
)
}
val limit = parsePositiveLimit(argsWithoutTable(2).eval())
new FullTextSearch(queryText, limit, columnName)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package org.apache.paimon.spark.sql
import org.apache.paimon.data.{BinaryString, GenericRow, Timestamp}
import org.apache.paimon.manifest.ManifestCommittable
import org.apache.paimon.spark.PaimonHiveTestBase
import org.apache.paimon.spark.catalyst.plans.logical.PaimonTableValuedFunctions
import org.apache.paimon.utils.DateTimeUtils

import org.apache.spark.sql.{DataFrame, Row}
Expand All @@ -30,6 +31,16 @@ import java.util.Collections

class TableValuedFunctionsTest extends PaimonHiveTestBase {

test("parse positive limit rejects overflowing long") {
val longValue: Long = 4294967297L
assert(longValue.toInt > 0)

val error = intercept[IllegalArgumentException] {
PaimonTableValuedFunctions.parsePositiveLimit(longValue)
}
assert(error.getMessage.contains("Limit must be no greater than"))
}

withPk.foreach {
hasPk =>
bucketModes.foreach {
Expand Down