From 923d299deb6743d8c58102bee8ccf4d7d3c484b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bojana=20Ze=C4=8Devi=C4=87?= Date: Mon, 29 Jun 2026 15:01:03 +0000 Subject: [PATCH 1/2] init --- .../spark/types/variant/VariantBuilder.java | 58 ++++++++++ .../reference/pyspark.sql/functions.rst | 1 + .../pyspark/sql/connect/functions/builtin.py | 10 ++ python/pyspark/sql/functions/__init__.py | 1 + python/pyspark/sql/functions/builtin.py | 47 ++++++++ python/pyspark/sql/tests/test_functions.py | 12 ++ .../org/apache/spark/sql/functions.scala | 26 +++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../variant/VariantExpressionEvalUtils.scala | 6 + .../variant/variantExpressions.scala | 58 +++++++++- .../variant/VariantExpressionSuite.scala | 57 +++++++++- .../spark/sql/PlanGenerationTestSuite.scala | 4 + .../function_variant_strip_nulls.explain | 2 + .../queries/function_variant_strip_nulls.json | 104 ++++++++++++++++++ .../function_variant_strip_nulls.proto.bin | Bin 0 -> 967 bytes .../sql-functions/sql-expression-schema.md | 1 + .../org/apache/spark/sql/VariantSuite.scala | 62 +++++++++++ 17 files changed, 448 insertions(+), 2 deletions(-) create mode 100644 sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_strip_nulls.explain create mode 100644 sql/connect/common/src/test/resources/query-tests/queries/function_variant_strip_nulls.json create mode 100644 sql/connect/common/src/test/resources/query-tests/queries/function_variant_strip_nulls.proto.bin diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java index 79f869ba6be7c..ce3aab0773c2e 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java @@ -130,6 +130,16 @@ public static Variant deleteAtPath(Variant v, PathSegment[] segments) { return builder.result(); } + // Return a new variant with null-valued object fields removed, recursing into nested objects + // and arrays. When `includeArrays` is true, null array elements are removed too; when false, + // arrays keep their nulls but objects inside them are still cleaned. Empty containers and a + // top-level variant null are left unchanged. The result is always rebuilt with fresh metadata. + public static Variant stripNulls(Variant v, boolean includeArrays) { + VariantBuilder builder = new VariantBuilder(false); + builder.appendWithNullStrippingImpl(v.value, v.metadata, v.pos, includeArrays); + return builder.result(); + } + // Build the variant metadata from `dictionaryKeys` and return the variant result. public Variant result() { int numKeys = dictionaryKeys.size(); @@ -549,6 +559,54 @@ private void appendWithDeletionImpl( } } + private void appendWithNullStrippingImpl( + byte[] value, byte[] metadata, int pos, boolean includeArrays) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + if (basicType == OBJECT) { + handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> { + ArrayList fields = new ArrayList<>(size); + int start = writePos; + for (int i = 0; i < size; ++i) { + int id = readUnsigned(value, idStart + idSize * i, idSize); + int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize); + int elementPos = dataStart + offset; + // Drop the whole field when its value is a variant null. + if (getType(value, elementPos) == Type.NULL) { + continue; + } + String key = getMetadataKey(metadata, id); + int newId = addKey(key); + fields.add(new FieldEntry(key, newId, writePos - start)); + appendWithNullStrippingImpl(value, metadata, elementPos, includeArrays); + } + finishWritingObject(start, fields); + return null; + }); + } else if (basicType == ARRAY) { + handleArray(value, pos, (size, offsetSize, offsetStart, dataStart) -> { + ArrayList offsets = new ArrayList<>(size); + int start = writePos; + for (int i = 0; i < size; ++i) { + int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize); + int elementPos = dataStart + offset; + // Drop variant-null elements only when stripping arrays; otherwise keep them but still + // recurse into nested containers. + if (includeArrays && getType(value, elementPos) == Type.NULL) { + continue; + } + offsets.add(writePos - start); + appendWithNullStrippingImpl(value, metadata, elementPos, includeArrays); + } + finishWritingArray(start, offsets); + return null; + }); + } else { + // Scalars and standalone variant nulls are appended unchanged. + appendVariantImpl(value, metadata, pos); + } + } + // Append the variant value without rewriting or creating any metadata. This is used when // building an object during shredding, where there is a fixed pre-existing metadata that // all shredded values will refer to. diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 3ad3ae9cdf127..4a99d3842c498 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -602,6 +602,7 @@ VARIANT Functions schema_of_variant_agg try_variant_get variant_delete + variant_strip_nulls variant_get try_parse_json to_variant_object diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 00183fe283f0a..b53e703d9145c 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -2219,6 +2219,16 @@ def variant_delete(v: "ColumnOrName", *paths: Union[Column, str]) -> Column: variant_delete.__doc__ = pysparkfuncs.variant_delete.__doc__ +def variant_strip_nulls(v: "ColumnOrName", include_arrays: Union[Column, bool] = True) -> Column: + include_arrays_col = ( + include_arrays if isinstance(include_arrays, Column) else lit(include_arrays) + ) + return _invoke_function("variant_strip_nulls", _to_col(v), include_arrays_col) + + +variant_strip_nulls.__doc__ = pysparkfuncs.variant_strip_nulls.__doc__ + + def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column: assert isinstance(path, (Column, str)) if isinstance(path, str): diff --git a/python/pyspark/sql/functions/__init__.py b/python/pyspark/sql/functions/__init__.py index 914b9c7fbcb79..b96954343a193 100644 --- a/python/pyspark/sql/functions/__init__.py +++ b/python/pyspark/sql/functions/__init__.py @@ -480,6 +480,7 @@ "schema_of_variant_agg", "try_variant_get", "variant_delete", + "variant_strip_nulls", "variant_get", "try_parse_json", "to_variant_object", diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index bca4704962f96..0858b941a01e5 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -21626,6 +21626,53 @@ def variant_delete(v: "ColumnOrName", *paths: Union[Column, str]) -> Column: ) +@_try_remote_functions +def variant_strip_nulls(v: "ColumnOrName", include_arrays: Union[Column, bool] = True) -> Column: + """ + Recursively removes null fields from variant objects, and null elements from arrays unless + `include_arrays` is False. Returns NULL if any argument is NULL. + + .. versionadded:: 4.3.0 + + Parameters + ---------- + v : :class:`~pyspark.sql.Column` or str + a variant column or column name + include_arrays : :class:`~pyspark.sql.Column` or bool, optional + whether null elements are also removed from arrays. If False, array null elements are kept + while null fields of nested objects are still removed. Defaults to True. + + Returns + ------- + :class:`~pyspark.sql.Column` + a variant column with variant null fields/elements removed + + Examples + -------- + >>> from pyspark.sql.functions import lit, parse_json, to_json, variant_strip_nulls + >>> df = spark.createDataFrame([{ + ... 'json': '''{ "a" : 1, "b" : null, "c" : [1, null], "d" : { "e" : null, "f" : 4 } }''' + ... }]) + >>> v = parse_json(df.json) + >>> df.select(to_json(variant_strip_nulls(v)).alias("r")).collect() + [Row(r='{"a":1,"c":[1],"d":{"f":4}}')] + >>> df.select(to_json(variant_strip_nulls(v, False)).alias("r")).collect() + [Row(r='{"a":1,"c":[1,null],"d":{"f":4}}')] + >>> df.select(variant_strip_nulls(lit(None)).alias("r")).collect() + [Row(r=None)] + """ + from pyspark.sql.classic.column import _to_java_column + + include_arrays_col = ( + include_arrays if isinstance(include_arrays, Column) else lit(include_arrays) + ) + return _invoke_function( + "variant_strip_nulls", + _to_java_column(v), + _to_java_column(include_arrays_col), + ) + + @_try_remote_functions def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column: """ diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 16928193db6da..70162956e0cc9 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -3524,6 +3524,18 @@ def check(resultDf, expected): df.select(F.to_json(F.variant_delete(v, F.lit(None)))), ['{"a":1}', '{"b":2}'], ) + check(df.select(F.to_json(F.variant_strip_nulls(v))), ['{"a":1}', '{"b":2}']) + check( + df.select(F.to_json(F.variant_strip_nulls(F.parse_json(F.lit('{"a": 1, "b": null}'))))), + ['{"a":1}', '{"a":1}'], + ) + inc = df.path == "$.a" + check( + df.select( + F.to_json(F.variant_strip_nulls(F.parse_json(F.lit('{"x": [1, null]}')), inc)) + ), + ['{"x":[1]}', '{"x":[1,null]}'], + ) check(df.select(F.schema_of_variant(v)), ["OBJECT", "OBJECT"]) check(df.select(F.schema_of_variant_agg(v)), ["OBJECT"]) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 86a495088aadc..ef10ea356b516 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -9745,6 +9745,32 @@ object functions { def variant_delete(v: Column, path: String, paths: String*): Column = Column.fn("variant_delete", (v +: lit(path) +: paths.map(lit)): _*) + /** + * Recursively removes null fields from variant objects, and null elements from arrays. Returns + * NULL if `v` is NULL. + * + * @param v + * a variant column. + * @group variant_funcs + * @since 4.3.0 + */ + def variant_strip_nulls(v: Column): Column = Column.fn("variant_strip_nulls", v) + + /** + * Recursively removes null fields from variant objects, and null elements from arrays unless + * `includeArrays` is false. Returns NULL if any argument is NULL. + * + * @param v + * a variant column. + * @param includeArrays + * whether null elements are also removed from arrays. If false, array null elements are kept, + * but null fields of nested objects are still removed. + * @group variant_funcs + * @since 4.3.0 + */ + def variant_strip_nulls(v: Column, includeArrays: Boolean): Column = + Column.fn("variant_strip_nulls", v, lit(includeArrays)) + /** * Extracts a sub-variant from `v` according to `path` string, and then cast the sub-variant to * `targetType`. Returns null if the path does not exist. Throws an exception if the cast fails. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 737a9da8b2b71..f565ebe2925a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -986,6 +986,7 @@ object FunctionRegistry { expression[ToVariantObject]("to_variant_object"), expression[IsValidVariant]("is_valid_variant"), expression[VariantDelete]("variant_delete"), + expression[VariantStripNulls]("variant_strip_nulls"), // Spatial expression[ST_AsBinary]("st_asbinary"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala index 47afbcd78837e..84c38526a0801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala @@ -105,6 +105,12 @@ object VariantExpressionEvalUtils { def deleteAtPath(input: VariantVal, path: UTF8String): VariantVal = deleteAtPath(input, toJavaSegments(parseVariantDeletePath(path.toString))) + def stripNulls(input: VariantVal, includeArrays: Boolean): VariantVal = { + val v = new Variant(input.getValue, input.getMetadata) + val out = VariantBuilder.stripNulls(v, includeArrays) + new VariantVal(out.getValue, out.getMetadata) + } + /** Cast a Spark value from `dataType` into the variant type. */ def castToVariant(input: Any, dataType: DataType): VariantVal = { // Enforce strict check because it is illegal for input struct/map/variant to contain duplicate diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index bf5e4183de471..94285f22d9e33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.json.JsonInferSchema import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET} -import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, QuotingUtils} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} @@ -804,6 +804,62 @@ object VariantDelete { } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(v[, includeArrays]) - Recursively removes null fields from variant objects, " + + "and null elements from arrays unless `includeArrays` is false. " + + "Returns NULL if any argument is NULL.", + arguments = """ + Arguments: + * v - A variant value to mutate. + * includeArrays - An optional boolean (default true). If false, null array elements are + kept, but null fields of nested objects are still removed. + """, + examples = """ + Examples: + > SELECT _FUNC_(parse_json('{"a": 1, "b": null, "c": 3}')); + {"a":1,"c":3} + > SELECT _FUNC_(parse_json('[1, null, 3]')); + [1,3] + > SELECT _FUNC_(parse_json('{"a": {"b": null, "c": [1, null]}}')); + {"a":{"c":[1]}} + > SELECT _FUNC_(parse_json('{"a": [1, null], "b": null}'), false); + {"a":[1,null]} + > SELECT _FUNC_(NULL); + NULL + """, + since = "4.3.0", + group = "variant_funcs" +) +// scalastyle:on line.size.limit +case class VariantStripNulls(child: Expression, includeArrays: Expression) + extends RuntimeReplaceable + with ExpectsInputTypes + with BinaryLike[Expression] { + + def this(child: Expression) = this(child, Literal(true)) + + override def left: Expression = child + override def right: Expression = includeArrays + + override def dataType: DataType = VariantType + override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, BooleanType) + + override lazy val replacement: Expression = StaticInvoke( + VariantExpressionEvalUtils.getClass, + VariantType, + "stripNulls", + Seq(child, includeArrays), + inputTypes, + returnNullable = false) + + override def prettyName: String = "variant_strip_nulls" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): VariantStripNulls = + copy(child = newLeft, includeArrays = newRight) +} + case class VariantExplode(child: Expression) extends UnaryExpression with Generator with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(VariantType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala index 467cb6335a2e1..674c0d3d12843 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.types.variant.VariantBuilder +import org.apache.spark.types.variant.{Variant, VariantBuilder} import org.apache.spark.types.variant.VariantUtil._ import org.apache.spark.unsafe.types.{UTF8String, VariantVal} import org.apache.spark.util.collection.Utils.createArray @@ -1311,4 +1311,59 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { noPaths.checkInputDataTypes() } } + + test("variant_strip_nulls") { + // Strip `input`, render the result back to JSON, and compare. `includeArrays` defaults to true. + def check(input: String, expected: String, includeArrays: Boolean = true): Unit = { + val expr = VariantStripNulls(Literal(parseJson(input)), Literal(includeArrays)) + val result = replace(expr).eval().asInstanceOf[VariantVal] + val json = if (result == null) null + else new Variant(result.getValue, result.getMetadata).toJson(ZoneOffset.UTC) + assert(json == expected) + } + + // The single-argument constructor defaults `includeArrays` to true. + assert(new VariantStripNulls(Literal(parseJson("[1, null]"))).includeArrays == Literal(true)) + + check("""{"a": 1, "b": null, "c": 3}""", """{"a":1,"c":3}""") + check("[1, null, 3]", "[1,3]") + check("""{"user": {"name": "Alice", "age": null}}""", """{"user":{"name":"Alice"}}""") + check("""{"a": [1, null, {"b": null, "c": 2}]}""", """{"a":[1,{"c":2}]}""") + check("[[1, null], [null]]", "[[1],[]]") + + // Empty containers are preserved; the parent is never collapsed. + check("""{"a": null}""", "{}") + check("[null]", "[]") + check("""{"a": {"b": null}}""", """{"a":{}}""") + check("""{"a": [null]}""", """{"a":[]}""") + check("{}", "{}") + check("[]", "[]") + + // Top-level variant null and scalars are returned unchanged. + check("null", "null") + check("42", "42") + check("\"hi\"", "\"hi\"") + + check("""{"a": {"b": {"c": null, "d": 4}}}""", """{"a":{"b":{"d":4}}}""") + + // `includeArrays = false`. + check("""{"a": [1, null, 3], "b": null}""", """{"a":[1,null,3]}""", includeArrays = false) + check( + """[{"a": 1, "b": null}, null, {"c": null, "d": 4}]""", + """[{"a":1},null,{"d":4}]""", + includeArrays = false) + // `includeArrays = true` (explicit) strips array nulls. + check("""{"a": [1, null, 3]}""", """{"a":[1,3]}""", includeArrays = true) + + // SQL NULL variant input yields SQL NULL. + checkEvaluation( + Cast(new VariantStripNulls(Literal.create(null, VariantType)), StringType), null) + // NULL `includeArrays` yields SQL NULL (the expression is null intolerant). + checkEvaluation( + Cast( + VariantStripNulls( + Literal(parseJson("""{"a": null}""")), Literal.create(null, BooleanType)), + StringType), + null) + } } diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index a74b25459bad2..4a308e09a47df 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -2775,6 +2775,10 @@ class PlanGenerationTestSuite extends ConnectFunSuite with Logging { fn.variant_delete(fn.parse_json(fn.col("g")), "$.a", "$.b") } + functionTest("variant_strip_nulls") { + fn.variant_strip_nulls(fn.parse_json(fn.col("g")), false) + } + functionTest("variant_get") { fn.variant_get(fn.parse_json(fn.col("g")), "$", "int") } diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_strip_nulls.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_strip_nulls.explain new file mode 100644 index 0000000000000..59ccba1f34f73 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_strip_nulls.explain @@ -0,0 +1,2 @@ +Project [variant_strip_nulls(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true, true)), false) AS variant_strip_nulls(parse_json(g), false)#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_variant_strip_nulls.json b/sql/connect/common/src/test/resources/query-tests/queries/function_variant_strip_nulls.json new file mode 100644 index 0000000000000..19a249ce81888 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_variant_strip_nulls.json @@ -0,0 +1,104 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "variant_strip_nulls", + "arguments": [{ + "unresolvedFunction": { + "functionName": "parse_json", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "g" + }, + "common": { + "origin": { + "jvmOrigin": { + "stackTrace": [{ + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.functions$", + "methodName": "col", + "fileName": "functions.scala" + }, { + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", + "methodName": "~~trimmed~anonfun~~", + "fileName": "PlanGenerationTestSuite.scala" + }] + } + } + } + }], + "isInternal": false + }, + "common": { + "origin": { + "jvmOrigin": { + "stackTrace": [{ + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.functions$", + "methodName": "parse_json", + "fileName": "functions.scala" + }, { + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", + "methodName": "~~trimmed~anonfun~~", + "fileName": "PlanGenerationTestSuite.scala" + }] + } + } + } + }, { + "literal": { + "boolean": false + }, + "common": { + "origin": { + "jvmOrigin": { + "stackTrace": [{ + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.functions$", + "methodName": "variant_strip_nulls", + "fileName": "functions.scala" + }, { + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", + "methodName": "~~trimmed~anonfun~~", + "fileName": "PlanGenerationTestSuite.scala" + }] + } + } + } + }], + "isInternal": false + }, + "common": { + "origin": { + "jvmOrigin": { + "stackTrace": [{ + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.functions$", + "methodName": "variant_strip_nulls", + "fileName": "functions.scala" + }, { + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", + "methodName": "~~trimmed~anonfun~~", + "fileName": "PlanGenerationTestSuite.scala" + }] + } + } + } + }] + } +} \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_variant_strip_nulls.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_variant_strip_nulls.proto.bin new file mode 100644 index 0000000000000000000000000000000000000000..e1a82aceffcb3b167932be0b974ee7874a9b6de7 GIT binary patch literal 967 zcmdUtu}i~16vq1+QMko&T_hlaI7m4%bPORw99$d(om|3ea*fgJCEi^GmkdsVTmK0~ zP!L2x5M0E6#6QFf1#zgfTXDJJ@%_I0KHekk4!G|tWDKO;X&qeyy#j4im6#w;{Xw4P z(lGR!ci^W*H4vPo{tQLPxQ|NVMPRv*4gC?8fn{=;u~}%yvc}#@L-%&s{aPa3b5uD> z6F1{BmX#2iUOMh45K^MjJU-Edq|ZPd0@)buf!qPTfhZsjmgVL~p)wCkOpdvyEpY1T zGvN&@nV7tgdaLIoh3K#T3_P8nz$ly0Y%0AkQ0{Wc6*eA+T$_U`H@w%aYa4Pj;%SYt zkk+%V`|E$GCm(IM#_$B>5$FR%J99?4!a
m8$Scqni1y};KW(Lp>(OlB HpI7(>o3U8( literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index d8abc643afeb0..7447b30d23d9c 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -557,6 +557,7 @@ | org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct | | org.apache.spark.sql.catalyst.expressions.variant.VariantDelete | variant_delete | SELECT variant_delete(parse_json('{"a": 1, "b": 2, "c": 3, "items": [1, 2, 3]}'), NULL, '$.a', '$.c') | struct | | org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct | +| org.apache.spark.sql.catalyst.expressions.variant.VariantStripNulls | variant_strip_nulls | SELECT variant_strip_nulls(parse_json('{"a": 1, "b": null, "c": 3}')) | struct | | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('1','a/b') | struct1, a/b):boolean> | | org.apache.spark.sql.catalyst.expressions.xml.XPathDouble | xpath_double | SELECT xpath_double('12', 'sum(a/b)') | struct12, sum(a/b)):double> | | org.apache.spark.sql.catalyst.expressions.xml.XPathDouble | xpath_number | SELECT xpath_number('12', 'sum(a/b)') | struct12, sum(a/b)):double> | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala index 787292f55b5a4..607c4ac0e9409 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala @@ -295,6 +295,68 @@ class VariantSuite extends SharedSparkSession with ExpressionEvalHelper { } } + test("variant_strip_nulls with literal arguments") { + def rows(results: Any*): Seq[Row] = results.map(Row(_)) + + checkAnswer( + sql("SELECT to_json(variant_strip_nulls(parse_json('{\"a\": 1, \"b\": null, \"c\": 3}')))"), + rows("""{"a":1,"c":3}""")) + + checkAnswer( + sql("SELECT to_json(variant_strip_nulls(parse_json('[1, null, 3]')))"), + rows("[1,3]")) + + checkAnswer( + sql("SELECT to_json(variant_strip_nulls(parse_json(" + + "'{\"a\": [null, 3, {\"b\": null, \"c\": [null, 1]}], \"d\": null, " + + "\"e\": {\"f\": null, \"g\": 2}}')))"), + rows("""{"a":[3,{"c":[1]}],"e":{"g":2}}""")) + + // include_arrays = false keeps array null elements but still strips null fields of objects. + checkAnswer( + sql("SELECT to_json(variant_strip_nulls(" + + "parse_json('[{\"a\": 1, \"b\": null}, null, {\"c\": null, \"d\": 4}]'), false))"), + rows("""[{"a":1},null,{"d":4}]""")) + + // Empty containers are preserved. + checkAnswer( + sql("SELECT to_json(variant_strip_nulls(parse_json('{\"a\": null}')))"), + rows("{}")) + + checkAnswer( + sql("SELECT to_json(variant_strip_nulls(parse_json('[null, null]')))"), + rows("[]")) + + // Top-level variant null is unchanged. + checkAnswer( + sql("SELECT to_json(variant_strip_nulls(parse_json('null')))"), + rows("null")) + + checkAnswer( + sql("SELECT to_json(variant_strip_nulls(CAST(NULL AS VARIANT)))"), + rows(null)) + } + + test("variant_strip_nulls with dynamic arguments") { + def rows(results: Any*): Seq[Row] = results.map(Row(_)) + val df = Seq( + """{"a": [1, null], "b": null}""", + """{"x": null, "y": 2}""", + null + ).toDF("json") + val v = parse_json(col("json")) + + // Single-argument overload defaults include_arrays to true. + checkAnswer( + df.select(to_json(variant_strip_nulls(v)).alias("r")), + rows("""{"a":[1]}""", """{"y":2}""", null)) + + // Boolean overload with include_arrays = false preserves array null elements. + checkAnswer( + df.select(to_json(variant_strip_nulls(v, false)).alias("r")), + rows("""{"a":[1,null]}""", """{"y":2}""", null)) + } + test("round trip tests") { withSQLConf(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false") { val rand = new Random(42) From a0732b1f676272ece2e7363c60767e29e114dbe3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bojana=20Ze=C4=8Devi=C4=87?= Date: Tue, 30 Jun 2026 10:30:47 +0000 Subject: [PATCH 2/2] test fix --- python/pyspark/sql/connect/functions/builtin.py | 7 ++----- python/pyspark/sql/functions/builtin.py | 13 ++++--------- python/pyspark/sql/tests/test_functions.py | 12 +----------- .../function_variant_strip_nulls.explain | 2 +- 4 files changed, 8 insertions(+), 26 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index b53e703d9145c..f5455811176f6 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -2219,11 +2219,8 @@ def variant_delete(v: "ColumnOrName", *paths: Union[Column, str]) -> Column: variant_delete.__doc__ = pysparkfuncs.variant_delete.__doc__ -def variant_strip_nulls(v: "ColumnOrName", include_arrays: Union[Column, bool] = True) -> Column: - include_arrays_col = ( - include_arrays if isinstance(include_arrays, Column) else lit(include_arrays) - ) - return _invoke_function("variant_strip_nulls", _to_col(v), include_arrays_col) +def variant_strip_nulls(v: "ColumnOrName", include_arrays: bool = True) -> Column: + return _invoke_function("variant_strip_nulls", _to_col(v), lit(include_arrays)) variant_strip_nulls.__doc__ = pysparkfuncs.variant_strip_nulls.__doc__ diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 0858b941a01e5..6fc1234fcf7a7 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -21627,10 +21627,10 @@ def variant_delete(v: "ColumnOrName", *paths: Union[Column, str]) -> Column: @_try_remote_functions -def variant_strip_nulls(v: "ColumnOrName", include_arrays: Union[Column, bool] = True) -> Column: +def variant_strip_nulls(v: "ColumnOrName", include_arrays: bool = True) -> Column: """ Recursively removes null fields from variant objects, and null elements from arrays unless - `include_arrays` is False. Returns NULL if any argument is NULL. + `include_arrays` is False. Returns NULL if `v` is NULL. .. versionadded:: 4.3.0 @@ -21638,7 +21638,7 @@ def variant_strip_nulls(v: "ColumnOrName", include_arrays: Union[Column, bool] = ---------- v : :class:`~pyspark.sql.Column` or str a variant column or column name - include_arrays : :class:`~pyspark.sql.Column` or bool, optional + include_arrays : bool, optional whether null elements are also removed from arrays. If False, array null elements are kept while null fields of nested objects are still removed. Defaults to True. @@ -21663,13 +21663,8 @@ def variant_strip_nulls(v: "ColumnOrName", include_arrays: Union[Column, bool] = """ from pyspark.sql.classic.column import _to_java_column - include_arrays_col = ( - include_arrays if isinstance(include_arrays, Column) else lit(include_arrays) - ) return _invoke_function( - "variant_strip_nulls", - _to_java_column(v), - _to_java_column(include_arrays_col), + "variant_strip_nulls", _to_java_column(v), _enum_to_value(include_arrays) ) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 70162956e0cc9..c6dc1d9fd578a 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -3525,17 +3525,7 @@ def check(resultDf, expected): ['{"a":1}', '{"b":2}'], ) check(df.select(F.to_json(F.variant_strip_nulls(v))), ['{"a":1}', '{"b":2}']) - check( - df.select(F.to_json(F.variant_strip_nulls(F.parse_json(F.lit('{"a": 1, "b": null}'))))), - ['{"a":1}', '{"a":1}'], - ) - inc = df.path == "$.a" - check( - df.select( - F.to_json(F.variant_strip_nulls(F.parse_json(F.lit('{"x": [1, null]}')), inc)) - ), - ['{"x":[1]}', '{"x":[1,null]}'], - ) + check(df.select(F.to_json(F.variant_strip_nulls(v, False))), ['{"a":1}', '{"b":2}']) check(df.select(F.schema_of_variant(v)), ["OBJECT", "OBJECT"]) check(df.select(F.schema_of_variant_agg(v)), ["OBJECT"]) diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_strip_nulls.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_strip_nulls.explain index 59ccba1f34f73..f6d8330de8280 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_strip_nulls.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_variant_strip_nulls.explain @@ -1,2 +1,2 @@ -Project [variant_strip_nulls(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true, true)), false) AS variant_strip_nulls(parse_json(g), false)#0] +Project [static_invoke(VariantExpressionEvalUtils.stripNulls(static_invoke(VariantExpressionEvalUtils.parseJson(g#0, false, true, true)), false)) AS variant_strip_nulls(parse_json(g), false)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]