Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.json

import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter}

import scala.collection.mutable
import scala.util.parsing.combinator.RegexParsers

import com.fasterxml.jackson.core._
Expand Down Expand Up @@ -575,24 +576,30 @@ case class GetJsonObjectEvaluator(cachedPath: UTF8String) {
}

/**
* Evaluates multiple simple top-level JSON fields in one parse.
* Evaluates multiple simple named JSON paths in one parse.
*/
case class MultiGetJsonObjectEvaluator(
fieldNames: Seq[String],
fallbackPaths: Seq[UTF8String]) {
fallbackPaths: Seq[UTF8String],
namedPaths: Seq[Seq[String]]) {
import SharedFactory._

require(
fieldNames.nonEmpty &&
fieldNames.distinct.length == fieldNames.length &&
fallbackPaths.length == fieldNames.length)
require(fallbackPaths.nonEmpty && namedPaths.length == fallbackPaths.length)

@transient
private lazy val fieldToOrdinal: Map[String, Int] = fieldNames.zipWithIndex.toMap
private lazy val useTopLevelFastPath: Boolean =
namedPaths.forall(_.length == 1) && namedPaths.distinct.length == namedPaths.length

@transient
private lazy val topLevelFieldToOrdinal: Map[String, Int] =
namedPaths.zipWithIndex.map { case (path, ordinal) => path.head -> ordinal }.toMap

@transient
private lazy val pathTrie: MultiGetJsonObjectEvaluator.PathTrieNode =
MultiGetJsonObjectEvaluator.buildPathTrie(namedPaths)

@transient
private lazy val nullRow: InternalRow =
new GenericInternalRow(Array.ofDim[Any](fieldNames.length))
new GenericInternalRow(Array.ofDim[Any](fallbackPaths.length))

@transient
private lazy val fallbackEvaluators: Seq[GetJsonObjectEvaluator] =
Expand All @@ -611,34 +618,18 @@ case class MultiGetJsonObjectEvaluator(
def evaluate(json: UTF8String): InternalRow = {
if (json == null) return null

val values = Array.ofDim[Any](fieldNames.length)
val matched = Array.ofDim[Boolean](fieldNames.length)
val values = Array.ofDim[Any](fallbackPaths.length)
val matched = Array.ofDim[Boolean](fallbackPaths.length)

try {
val validObject = Utils.tryWithResource(
CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
if (parser.nextToken() != JsonToken.START_OBJECT) {
false
} else if (useTopLevelFastPath) {
extractTopLevelObject(parser, values, matched)
} else {
var token = parser.nextToken()
while (token != null && token != JsonToken.END_OBJECT) {
if (token == JsonToken.FIELD_NAME) {
val fieldName = parser.currentName
val ordinal = fieldToOrdinal.get(fieldName).filter(!matched(_))
val valueToken = parser.nextToken()
if (ordinal.nonEmpty && valueToken != JsonToken.VALUE_NULL) {
val index = ordinal.get
matched(index) = true
copyCurrentStructure(parser).foreach(value => values(index) = value)
} else {
parser.skipChildren()
}
} else {
parser.skipChildren()
}
token = parser.nextToken()
}
token == JsonToken.END_OBJECT
extractObject(parser, pathTrie, values, matched)
}
}
if (validObject) {
Expand All @@ -647,15 +638,88 @@ case class MultiGetJsonObjectEvaluator(
nullRow
}
} catch {
// Every simple top-level legacy extraction scans through the root object's closing token,
// so a syntax failure makes every sibling null without needing per-path reparsing.
// Every simple named legacy extraction scans through the root object's closing token, so a
// syntax failure makes every sibling null without needing per-path reparsing.
case _: JsonParseException => nullRow
// A parser-side rendering failure can leave the shared token stream unusable. Reparse each
// path with the legacy evaluator so one bad selected value cannot erase sibling results.
// A parser-side rendering failure, such as a string-length constraint violation, can leave
// the shared token stream unusable. Reparse each path with the legacy evaluator so one bad
// selected value cannot erase independent sibling results.
case _: JsonProcessingException => fallback(json)
}
}

/**
 * Fast path used when every requested path is a single, distinct top-level field.
 *
 * Walks the fields of the object the parser is currently positioned in. For each field whose
 * name maps to a still-unmatched ordinal and whose value is not JSON null, the ordinal is
 * marked as matched and the rendered value (if rendering succeeded) is stored in `values`.
 * A null-valued field is skipped without marking its ordinal, so a later duplicate key can
 * still fill that slot.
 *
 * @param parser  parser positioned just after the object's START_OBJECT token
 * @param values  output slots, indexed by path ordinal
 * @param matched per-ordinal flags recording which paths were already consumed
 * @return true if the scan ended on the object's END_OBJECT token, false if input ran out
 */
private def extractTopLevelObject(
    parser: JsonParser,
    values: Array[Any],
    matched: Array[Boolean]): Boolean = {
  var current = parser.nextToken()
  while (current != null && current != JsonToken.END_OBJECT) {
    if (current != JsonToken.FIELD_NAME) {
      parser.skipChildren()
    } else {
      // Resolve the target slot before advancing: the field name belongs to the current token.
      val target = topLevelFieldToOrdinal.get(parser.currentName).filter(slot => !matched(slot))
      val fieldValue = parser.nextToken()
      target match {
        case Some(slot) if fieldValue != JsonToken.VALUE_NULL =>
          matched(slot) = true
          copyCurrentStructure(parser).foreach(rendered => values(slot) = rendered)
        case _ =>
          parser.skipChildren()
      }
    }
    current = parser.nextToken()
  }
  current == JsonToken.END_OBJECT
}

/**
 * Walks the fields of the current JSON object, descending into every field that matches a
 * child of `node` which still has at least one unmatched path beneath it.
 *
 * Fields that do not match, whose subtree is already fully matched, or whose value is JSON
 * null are skipped together with their children.
 *
 * @param parser  parser positioned just after the object's START_OBJECT token
 * @param node    trie node describing the path segments expected inside this object
 * @param values  output slots, indexed by path ordinal
 * @param matched per-ordinal flags recording which paths were already consumed
 * @return true if this object and every nested object visited were closed by END_OBJECT
 */
private def extractObject(
    parser: JsonParser,
    node: MultiGetJsonObjectEvaluator.PathTrieNode,
    values: Array[Any],
    matched: Array[Boolean]): Boolean = {
  var ok = true
  var current = parser.nextToken()
  while (ok && current != null && current != JsonToken.END_OBJECT) {
    if (current != JsonToken.FIELD_NAME) {
      parser.skipChildren()
    } else {
      // Look the child up before advancing: the field name belongs to the current token.
      val wanted = node.children.get(parser.currentName).filter(_.hasUnmatched(matched))
      val fieldValue = parser.nextToken()
      wanted match {
        case Some(childNode) if fieldValue != JsonToken.VALUE_NULL =>
          ok = extractValue(parser, childNode, values, matched)
        case _ =>
          parser.skipChildren()
      }
    }
    // Once a nested extraction has failed, stop consuming tokens.
    if (ok) {
      current = parser.nextToken()
    }
  }
  ok && current == JsonToken.END_OBJECT
}

/**
 * Handles the value of a field that matched `node` in the path trie.
 *
 *  - If `node` terminates one or more paths, all of their ordinals are marked as matched and
 *    the rendered value, when rendering succeeds, is stored in each of their slots.
 *  - Otherwise, if the value is an object, extraction continues inside it.
 *  - Any other value cannot contain the remaining path segments and is skipped.
 *
 * @param parser  parser positioned on the first token of the field's value
 * @param node    trie node matched by the field name
 * @param values  output slots, indexed by path ordinal
 * @param matched per-ordinal flags recording which paths were already consumed
 * @return false only when a nested object was not properly closed; true otherwise
 */
private def extractValue(
    parser: JsonParser,
    node: MultiGetJsonObjectEvaluator.PathTrieNode,
    values: Array[Any],
    matched: Array[Boolean]): Boolean = {
  // Optimizer-generated paths are deduplicated. Multiple ordinals defensively support
  // directly constructed internal expressions with duplicate paths.
  if (node.terminalOrdinals.nonEmpty) {
    // Mark before rendering so the paths count as consumed even if rendering yields no value.
    node.terminalOrdinals.foreach { ordinal => matched(ordinal) = true }
    copyCurrentStructure(parser).foreach { result =>
      node.terminalOrdinals.foreach { ordinal => values(ordinal) = result }
    }
    true
  } else if (parser.currentToken == JsonToken.START_OBJECT) {
    extractObject(parser, node, values, matched)
  } else {
    parser.skipChildren()
    true
  }
}

private def copyCurrentStructure(parser: JsonParser): Option[UTF8String] = {
outputBuffer.reset()
var renderingFailed = false
Expand Down Expand Up @@ -726,3 +790,43 @@ case class MultiGetJsonObjectEvaluator(
if (renderingFailed) None else Some(UTF8String.fromBytes(outputBuffer.toByteArray))
}
}

object MultiGetJsonObjectEvaluator {
  /**
   * Builder-side trie node. Paths are inserted by mutating `children` and `terminalOrdinals`;
   * `freeze()` then converts the subtree into immutable [[PathTrieNode]]s.
   */
  private final class MutablePathTrieNode {
    val terminalOrdinals: mutable.ArrayBuffer[Int] = mutable.ArrayBuffer.empty
    val children: mutable.LinkedHashMap[String, MutablePathTrieNode] = mutable.LinkedHashMap.empty

    /**
     * Recursively converts this node and its subtree into an immutable trie.
     * Fails the requirement if a path both ends at this node and continues below it.
     */
    def freeze(): PathTrieNode = {
      require(
        terminalOrdinals.isEmpty || children.isEmpty,
        "Shared JSON paths must not be prefixes of one another")
      val frozenChildren: Map[String, PathTrieNode] =
        children.iterator.map { case (segment, sub) => segment -> sub.freeze() }.toMap
      // Ordinals ending here first, followed by everything reachable through the children.
      val reachable = Array.newBuilder[Int]
      reachable ++= terminalOrdinals
      frozenChildren.valuesIterator.foreach { sub => reachable ++= sub.descendantOrdinals }
      PathTrieNode(terminalOrdinals.toArray, frozenChildren, reachable.result())
    }
  }

  /**
   * Immutable trie node over named path segments.
   *
   * @param terminalOrdinals   ordinals of the paths that end exactly at this node
   * @param children           next path segment mapped to its subtree
   * @param descendantOrdinals ordinals of every path ending at this node or anywhere below it
   */
  private case class PathTrieNode(
      terminalOrdinals: Array[Int],
      children: Map[String, PathTrieNode],
      descendantOrdinals: Array[Int]) {
    /** True while at least one path at or below this node has not been matched yet. */
    def hasUnmatched(matched: Array[Boolean]): Boolean = {
      !descendantOrdinals.forall(ordinal => matched(ordinal))
    }
  }

  /**
   * Builds the immutable trie for `paths`; each path's position in the sequence becomes the
   * ordinal recorded on the node where that path ends.
   */
  private def buildPathTrie(paths: Seq[Seq[String]]): PathTrieNode = {
    val root = new MutablePathTrieNode
    for ((segments, ordinal) <- paths.zipWithIndex) {
      val leaf = segments.foldLeft(root) { (cursor, segment) =>
        cursor.children.getOrElseUpdate(segment, new MutablePathTrieNode)
      }
      leaf.terminalOrdinals += ordinal
    }
    root.freeze()
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,16 @@ case class GetJsonObject(json: Expression, path: Expression)
}

object GetJsonObject {
private[sql] def simpleTopLevelField(path: UTF8String): Option[String] = {
import PathInstruction._

private[sql] def simpleNamedPath(path: UTF8String): Option[Seq[String]] = {
try {
Option(path).flatMap(value => JsonPathParser.parse(value.toString)).collect {
case List(PathInstruction.Key, PathInstruction.Named(fieldName)) => fieldName
Option(path).flatMap(value => JsonPathParser.parse(value.toString)).flatMap { instructions =>
val names = instructions.grouped(2).map {
case List(Key, Named(fieldName)) => Some(fieldName)
case _ => None
}.toSeq
if (names.nonEmpty && names.forall(_.isDefined)) Some(names.flatten) else None
}
} catch {
// Numeric subscripts are parsed as Long and can overflow before the parser returns None.
Expand All @@ -155,28 +161,25 @@ object GetJsonObject {
}

/**
* Extracts multiple simple top-level fields from a JSON string in one parse. This is an internal
* expression used to share sibling [[GetJsonObject]] expressions; unsupported JSON paths remain
* as independent GetJsonObject expressions.
* Extracts multiple simple named paths from a JSON string in one parse. This is an internal
* expression used to share sibling [[GetJsonObject]] expressions; unsupported and
* prefix-conflicting JSON paths remain as independent GetJsonObject expressions.
*/
case class MultiGetJsonObject(
json: Expression,
fieldNames: Seq[String],
fallbackPaths: Seq[String])
extends UnaryExpression
with ExpectsInputTypes {

require(
fieldNames.nonEmpty &&
fieldNames.distinct.length == fieldNames.length &&
fallbackPaths.length == fieldNames.length)
// OptimizeCsvJsonExprs caps shared path depth to keep evaluator recursion stack-safe.
require(fallbackPaths.nonEmpty)
Comment thread
sunchao marked this conversation as resolved.

override def child: Expression = json

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))

override lazy val dataType: DataType = StructType(fieldNames.indices.map { index =>
override lazy val dataType: DataType = StructType(fallbackPaths.indices.map { index =>
StructField(s"_$index", StringType, nullable = true)
})

Expand All @@ -189,10 +192,17 @@ case class MultiGetJsonObject(

final override val nodePatterns: Seq[TreePattern] = Seq(GET_JSON_OBJECT)

@transient
private lazy val namedPaths = fallbackPaths.map { path =>
Comment thread
sunchao marked this conversation as resolved.
GetJsonObject.simpleNamedPath(UTF8String.fromString(path)).getOrElse {
throw new IllegalArgumentException(s"Unsupported shared JSON path: $path")
}
}

@transient
private lazy val evaluator = MultiGetJsonObjectEvaluator(
fieldNames,
fallbackPaths.map(UTF8String.fromString))
fallbackPaths.map(UTF8String.fromString),
namedPaths)

override def eval(input: InternalRow): Any = {
evaluator.evaluate(json.eval(input).asInstanceOf[UTF8String])
Expand Down
Loading