Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ public static Variant deleteAtPath(Variant v, PathSegment[] segments) {
return builder.result();
}

// Return a new variant with null-valued object fields removed, recursing into nested objects
// and arrays. When `includeArrays` is true, null array elements are removed too; when false,
// arrays keep their nulls but objects inside them are still cleaned. Empty containers and a
// top-level variant null are left unchanged. The result is always rebuilt with fresh metadata.
public static Variant stripNulls(Variant v, boolean includeArrays) {
VariantBuilder builder = new VariantBuilder(false);
builder.appendWithNullStrippingImpl(v.value, v.metadata, v.pos, includeArrays);
return builder.result();
}

// Build the variant metadata from `dictionaryKeys` and return the variant result.
public Variant result() {
int numKeys = dictionaryKeys.size();
Expand Down Expand Up @@ -549,6 +559,54 @@ private void appendWithDeletionImpl(
}
}

private void appendWithNullStrippingImpl(
byte[] value, byte[] metadata, int pos, boolean includeArrays) {
checkIndex(pos, value.length);
int basicType = value[pos] & BASIC_TYPE_MASK;
if (basicType == OBJECT) {
handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> {
ArrayList<FieldEntry> fields = new ArrayList<>(size);
int start = writePos;
for (int i = 0; i < size; ++i) {
int id = readUnsigned(value, idStart + idSize * i, idSize);
int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize);
int elementPos = dataStart + offset;
// Drop the whole field when its value is a variant null.
if (getType(value, elementPos) == Type.NULL) {
continue;
}
String key = getMetadataKey(metadata, id);
int newId = addKey(key);
fields.add(new FieldEntry(key, newId, writePos - start));
appendWithNullStrippingImpl(value, metadata, elementPos, includeArrays);
}
finishWritingObject(start, fields);
return null;
});
} else if (basicType == ARRAY) {
handleArray(value, pos, (size, offsetSize, offsetStart, dataStart) -> {
ArrayList<Integer> offsets = new ArrayList<>(size);
int start = writePos;
for (int i = 0; i < size; ++i) {
int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize);
int elementPos = dataStart + offset;
// Drop variant-null elements only when stripping arrays; otherwise keep them but still
// recurse into nested containers.
if (includeArrays && getType(value, elementPos) == Type.NULL) {
continue;
}
offsets.add(writePos - start);
appendWithNullStrippingImpl(value, metadata, elementPos, includeArrays);
}
finishWritingArray(start, offsets);
return null;
});
} else {
// Scalars and standalone variant nulls are appended unchanged.
appendVariantImpl(value, metadata, pos);
}
}

// Append the variant value without rewriting or creating any metadata. This is used when
// building an object during shredding, where there is a fixed pre-existing metadata that
// all shredded values will refer to.
Expand Down
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ VARIANT Functions
schema_of_variant_agg
try_variant_get
variant_delete
variant_strip_nulls
variant_get
try_parse_json
to_variant_object
Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2219,6 +2219,13 @@ def variant_delete(v: "ColumnOrName", *paths: Union[Column, str]) -> Column:
variant_delete.__doc__ = pysparkfuncs.variant_delete.__doc__


def variant_strip_nulls(v: "ColumnOrName", include_arrays: bool = True) -> Column:
return _invoke_function("variant_strip_nulls", _to_col(v), lit(include_arrays))


variant_strip_nulls.__doc__ = pysparkfuncs.variant_strip_nulls.__doc__


def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column:
assert isinstance(path, (Column, str))
if isinstance(path, str):
Expand Down
1 change: 1 addition & 0 deletions python/pyspark/sql/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@
"schema_of_variant_agg",
"try_variant_get",
"variant_delete",
"variant_strip_nulls",
"variant_get",
"try_parse_json",
"to_variant_object",
Expand Down
42 changes: 42 additions & 0 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21626,6 +21626,48 @@ def variant_delete(v: "ColumnOrName", *paths: Union[Column, str]) -> Column:
)


@_try_remote_functions
def variant_strip_nulls(v: "ColumnOrName", include_arrays: bool = True) -> Column:
"""
Recursively removes null fields from variant objects, and null elements from arrays unless
`include_arrays` is False. Returns NULL if `v` is NULL.

.. versionadded:: 4.3.0

Parameters
----------
v : :class:`~pyspark.sql.Column` or str
a variant column or column name
include_arrays : bool, optional
whether null elements are also removed from arrays. If False, array null elements are kept
while null fields of nested objects are still removed. Defaults to True.

Returns
-------
:class:`~pyspark.sql.Column`
a variant column with variant null fields/elements removed

Examples
--------
>>> from pyspark.sql.functions import lit, parse_json, to_json, variant_strip_nulls
>>> df = spark.createDataFrame([{
... 'json': '''{ "a" : 1, "b" : null, "c" : [1, null], "d" : { "e" : null, "f" : 4 } }'''
... }])
>>> v = parse_json(df.json)
>>> df.select(to_json(variant_strip_nulls(v)).alias("r")).collect()
[Row(r='{"a":1,"c":[1],"d":{"f":4}}')]
>>> df.select(to_json(variant_strip_nulls(v, False)).alias("r")).collect()
[Row(r='{"a":1,"c":[1,null],"d":{"f":4}}')]
>>> df.select(variant_strip_nulls(lit(None)).alias("r")).collect()
[Row(r=None)]
"""
from pyspark.sql.classic.column import _to_java_column

return _invoke_function(
"variant_strip_nulls", _to_java_column(v), _enum_to_value(include_arrays)
)


@_try_remote_functions
def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column:
"""
Expand Down
2 changes: 2 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3524,6 +3524,8 @@ def check(resultDf, expected):
df.select(F.to_json(F.variant_delete(v, F.lit(None)))),
['{"a":1}', '{"b":2}'],
)
check(df.select(F.to_json(F.variant_strip_nulls(v))), ['{"a":1}', '{"b":2}'])
check(df.select(F.to_json(F.variant_strip_nulls(v, False))), ['{"a":1}', '{"b":2}'])
check(df.select(F.schema_of_variant(v)), ["OBJECT<a: BIGINT>", "OBJECT<b: BIGINT>"])
check(df.select(F.schema_of_variant_agg(v)), ["OBJECT<a: BIGINT, b: BIGINT>"])

Expand Down
26 changes: 26 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9745,6 +9745,32 @@ object functions {
def variant_delete(v: Column, path: String, paths: String*): Column =
Column.fn("variant_delete", (v +: lit(path) +: paths.map(lit)): _*)

/**
* Recursively removes null fields from variant objects, and null elements from arrays. Returns
* NULL if `v` is NULL.
*
* @param v
* a variant column.
* @group variant_funcs
* @since 4.3.0
*/
def variant_strip_nulls(v: Column): Column = Column.fn("variant_strip_nulls", v)

/**
* Recursively removes null fields from variant objects, and null elements from arrays unless
* `includeArrays` is false. Returns NULL if any argument is NULL.
*
* @param v
* a variant column.
* @param includeArrays
* whether null elements are also removed from arrays. If false, array null elements are kept,
* but null fields of nested objects are still removed.
* @group variant_funcs
* @since 4.3.0
*/
def variant_strip_nulls(v: Column, includeArrays: Boolean): Column =
Column.fn("variant_strip_nulls", v, lit(includeArrays))

/**
* Extracts a sub-variant from `v` according to `path` string, and then cast the sub-variant to
* `targetType`. Returns null if the path does not exist. Throws an exception if the cast fails.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,7 @@ object FunctionRegistry {
expression[ToVariantObject]("to_variant_object"),
expression[IsValidVariant]("is_valid_variant"),
expression[VariantDelete]("variant_delete"),
expression[VariantStripNulls]("variant_strip_nulls"),

// Spatial
expression[ST_AsBinary]("st_asbinary"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ object VariantExpressionEvalUtils {
def deleteAtPath(input: VariantVal, path: UTF8String): VariantVal =
deleteAtPath(input, toJavaSegments(parseVariantDeletePath(path.toString)))

def stripNulls(input: VariantVal, includeArrays: Boolean): VariantVal = {
val v = new Variant(input.getValue, input.getMetadata)
val out = VariantBuilder.stripNulls(v, includeArrays)
new VariantVal(out.getValue, out.getMetadata)
}

/** Cast a Spark value from `dataType` into the variant type. */
def castToVariant(input: Any, dataType: DataType): VariantVal = {
// Enforce strict check because it is illegal for input struct/map/variant to contain duplicate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.json.JsonInferSchema
import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter}
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, QuotingUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
Expand Down Expand Up @@ -804,6 +804,62 @@ object VariantDelete {
}
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(v[, includeArrays]) - Recursively removes null fields from variant objects, " +
"and null elements from arrays unless `includeArrays` is false. " +
"Returns NULL if any argument is NULL.",
arguments = """
Arguments:
* v - A variant value to mutate.
* includeArrays - An optional boolean (default true). If false, null array elements are
kept, but null fields of nested objects are still removed.
""",
examples = """
Examples:
> SELECT _FUNC_(parse_json('{"a": 1, "b": null, "c": 3}'));
{"a":1,"c":3}
> SELECT _FUNC_(parse_json('[1, null, 3]'));
[1,3]
> SELECT _FUNC_(parse_json('{"a": {"b": null, "c": [1, null]}}'));
{"a":{"c":[1]}}
> SELECT _FUNC_(parse_json('{"a": [1, null], "b": null}'), false);
{"a":[1,null]}
> SELECT _FUNC_(NULL);
NULL
""",
since = "4.3.0",
group = "variant_funcs"
)
// scalastyle:on line.size.limit
case class VariantStripNulls(child: Expression, includeArrays: Expression)
extends RuntimeReplaceable
with ExpectsInputTypes
with BinaryLike[Expression] {

def this(child: Expression) = this(child, Literal(true))

override def left: Expression = child
override def right: Expression = includeArrays

override def dataType: DataType = VariantType
override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, BooleanType)

override lazy val replacement: Expression = StaticInvoke(
VariantExpressionEvalUtils.getClass,
VariantType,
"stripNulls",
Seq(child, includeArrays),
inputTypes,
returnNullable = false)

override def prettyName: String = "variant_strip_nulls"

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): VariantStripNulls =
copy(child = newLeft, includeArrays = newRight)
}

case class VariantExplode(child: Expression) extends UnaryExpression with Generator
with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(VariantType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.types.variant.VariantBuilder
import org.apache.spark.types.variant.{Variant, VariantBuilder}
import org.apache.spark.types.variant.VariantUtil._
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
import org.apache.spark.util.collection.Utils.createArray
Expand Down Expand Up @@ -1311,4 +1311,59 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
noPaths.checkInputDataTypes()
}
}

test("variant_strip_nulls") {
// Strip `input`, render the result back to JSON, and compare. `includeArrays` defaults to true.
def check(input: String, expected: String, includeArrays: Boolean = true): Unit = {
val expr = VariantStripNulls(Literal(parseJson(input)), Literal(includeArrays))
val result = replace(expr).eval().asInstanceOf[VariantVal]
val json = if (result == null) null
else new Variant(result.getValue, result.getMetadata).toJson(ZoneOffset.UTC)
assert(json == expected)
}

// The single-argument constructor defaults `includeArrays` to true.
assert(new VariantStripNulls(Literal(parseJson("[1, null]"))).includeArrays == Literal(true))

check("""{"a": 1, "b": null, "c": 3}""", """{"a":1,"c":3}""")
check("[1, null, 3]", "[1,3]")
check("""{"user": {"name": "Alice", "age": null}}""", """{"user":{"name":"Alice"}}""")
check("""{"a": [1, null, {"b": null, "c": 2}]}""", """{"a":[1,{"c":2}]}""")
check("[[1, null], [null]]", "[[1],[]]")

// Empty containers are preserved; the parent is never collapsed.
check("""{"a": null}""", "{}")
check("[null]", "[]")
check("""{"a": {"b": null}}""", """{"a":{}}""")
check("""{"a": [null]}""", """{"a":[]}""")
check("{}", "{}")
check("[]", "[]")

// Top-level variant null and scalars are returned unchanged.
check("null", "null")
check("42", "42")
check("\"hi\"", "\"hi\"")

check("""{"a": {"b": {"c": null, "d": 4}}}""", """{"a":{"b":{"d":4}}}""")

// `includeArrays = false`.
check("""{"a": [1, null, 3], "b": null}""", """{"a":[1,null,3]}""", includeArrays = false)
check(
"""[{"a": 1, "b": null}, null, {"c": null, "d": 4}]""",
"""[{"a":1},null,{"d":4}]""",
includeArrays = false)
// `includeArrays = true` (explicit) strips array nulls.
check("""{"a": [1, null, 3]}""", """{"a":[1,3]}""", includeArrays = true)

// SQL NULL variant input yields SQL NULL.
checkEvaluation(
Cast(new VariantStripNulls(Literal.create(null, VariantType)), StringType), null)
// NULL `includeArrays` yields SQL NULL (the expression is null intolerant).
checkEvaluation(
Cast(
VariantStripNulls(
Literal(parseJson("""{"a": null}""")), Literal.create(null, BooleanType)),
StringType),
null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2775,6 +2775,10 @@ class PlanGenerationTestSuite extends ConnectFunSuite with Logging {
fn.variant_delete(fn.parse_json(fn.col("g")), "$.a", "$.b")
}

functionTest("variant_strip_nulls") {
fn.variant_strip_nulls(fn.parse_json(fn.col("g")), false)
}

functionTest("variant_get") {
fn.variant_get(fn.parse_json(fn.col("g")), "$", "int")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Project [static_invoke(VariantExpressionEvalUtils.stripNulls(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true, true)), false)) AS variant_strip_nulls(parse_json(g), false)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Loading