Skip to content

Commit a34467d

Browse files
committed
[#37994] Fix NullPointerException in Spark Runner with multiple outputs and serialization
1 parent cf3d6ed commit a34467d

5 files changed

Lines changed: 78 additions & 2 deletions

File tree

runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.apache.beam.runners.spark.translation;
1919

20+
import java.util.Objects;
2021
import static org.apache.beam.runners.spark.translation.TranslationUtils.canAvoidRddSerialization;
2122
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
2223
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
@@ -486,6 +487,8 @@ public void evaluate(
486487
TranslationUtils.getTupleTagCoders(outputs);
487488
all =
488489
all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
490+
.filter(Objects::nonNull) // skip nulls to avoid encoding; nulls are tags that
491+
// are not read
489492
.persist(level)
490493
.mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
491494
}

runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,13 @@ public static Map<TupleTag<?>, Coder<WindowedValue<?>>> getTupleTagCoders(
445445
return tuple2 -> {
446446
TupleTag<?> tupleTag = tuple2._1;
447447
WindowedValue<?> windowedValue = tuple2._2;
448-
return new Tuple2<>(
449-
tupleTag, ValueAndCoderLazySerializable.of(windowedValue, coderMap.get(tupleTag)));
448+
Coder<WindowedValue<?>> coder = coderMap.get(tupleTag);
449+
if (coder == null) {
450+
// there is no coder because this is a leaf step whose output is not read anywhere,
451+
// so the coder was pruned from coderMap
452+
return null;
453+
}
454+
return new Tuple2<>(tupleTag, ValueAndCoderLazySerializable.of(windowedValue, coder));
450455
};
451456
}
452457

runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.apache.beam.runners.spark.translation.streaming;
1919

20+
import java.util.Objects;
2021
import static org.apache.beam.runners.spark.translation.TranslationUtils.getBatchDuration;
2122
import static org.apache.beam.runners.spark.translation.TranslationUtils.hasEventTimers;
2223
import static org.apache.beam.runners.spark.translation.TranslationUtils.hasTimers;
@@ -234,6 +235,7 @@ public void evaluate(
234235
TranslationUtils.getTupleTagCoders(outputs);
235236
all =
236237
all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
238+
.filter(Objects::nonNull) // skip nulls to avoid encoding; nulls are tags that are not read
237239
.cache()
238240
.mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
239241

runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.apache.beam.runners.spark.translation.streaming;
1919

20+
import java.util.Objects;
2021
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
2122
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
2223

@@ -593,6 +594,7 @@ public void evaluate(
593594
TranslationUtils.getTupleTagCoders(outputs);
594595
all =
595596
all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
597+
.filter(Objects::nonNull) // skip nulls to avoid encoding; nulls are tags that are not read
596598
.cache()
597599
.mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
598600
}

runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
import org.apache.beam.sdk.coders.VarIntCoder;
4141
import org.apache.beam.sdk.transforms.Count;
4242
import org.apache.beam.sdk.transforms.Create;
43+
import org.apache.beam.sdk.transforms.DoFn;
44+
import org.apache.beam.sdk.transforms.ParDo;
4345
import org.apache.beam.sdk.transforms.PTransform;
4446
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
4547
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
@@ -48,6 +50,7 @@
4850
import org.apache.beam.sdk.values.PCollection;
4951
import org.apache.beam.sdk.values.PCollectionTuple;
5052
import org.apache.beam.sdk.values.TupleTag;
53+
import org.apache.beam.sdk.values.TupleTagList;
5154
import org.apache.beam.sdk.values.WindowedValue;
5255
import org.apache.beam.sdk.values.WindowedValues;
5356
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
@@ -247,4 +250,65 @@ public void testMultipleOutputParDoShouldHaveFilterWhenSideOutputIsConsumed() {
247250
assertTrue(parsed.stream().anyMatch(e -> e.getName().contains(tag.getId())));
248251
}
249252
}
253+
254+
@Test
255+
public void testMultipleOutputParDoWithUnconsumedSideOutputAndSerializationStorageLevel() {
256+
Pipeline p = Pipeline.create();
257+
TupleTag<String> tag1 = new TupleTag<String>("tag1") {};
258+
TupleTag<String> tag2 = new TupleTag<String>("tag2") {};
259+
TupleTag<String> tag3 = new TupleTag<String>("tag3") {};
260+
261+
SparkPipelineOptions options = contextRule.createPipelineOptions();
262+
// Force serialization by setting storage level to MEMORY_AND_DISK_SER
263+
options.setStorageLevel("MEMORY_AND_DISK_SER");
264+
265+
TransformTranslator.Translator translator = new TransformTranslator.Translator();
266+
267+
PTransform<PBegin, PCollection<String>> createTransform = Create.of("foo", "bar");
268+
269+
PCollectionTuple pCollectionTuple =
270+
p.apply("Create Values", createTransform)
271+
.apply(
272+
"Multiple Output ParDo",
273+
ParDo.of(new MultiOutputDoFn(tag1, tag2, tag3))
274+
.withOutputTags(tag1, TupleTagList.of(tag2).and(tag3)));
275+
276+
// consume tag1 and tag2
277+
pCollectionTuple.get(tag1).apply("Count1", Count.globally());
278+
pCollectionTuple.get(tag2).apply("Count2", Count.globally());
279+
280+
p.replaceAll(SparkTransformOverrides.getDefaultOverrides(false));
281+
282+
EvaluationContext ctxt = new EvaluationContext(contextRule.getSparkContext(), p, options);
283+
SparkRunner.initAccumulators(options, ctxt.getSparkContext());
284+
SparkRunner.updateDependentTransforms(p, translator, ctxt);
285+
286+
// This should not throw NullPointerException
287+
p.traverseTopologically(new SparkRunner.Evaluator(translator, ctxt));
288+
289+
// Also trigger some action on the RDD to ensure serialization happens
290+
@SuppressWarnings("unchecked")
291+
BoundedDataset<String> dataset =
292+
(BoundedDataset<String>) ctxt.borrowDataset(pCollectionTuple.get(tag1));
293+
dataset.getRDD().count();
294+
}
295+
296+
private static class MultiOutputDoFn extends DoFn<String, String> {
297+
private final TupleTag<String> tag1;
298+
private final TupleTag<String> tag2;
299+
private final TupleTag<String> tag3;
300+
301+
MultiOutputDoFn(TupleTag<String> tag1, TupleTag<String> tag2, TupleTag<String> tag3) {
302+
this.tag1 = tag1;
303+
this.tag2 = tag2;
304+
this.tag3 = tag3;
305+
}
306+
307+
@ProcessElement
308+
public void process(@Element String input, MultiOutputReceiver outputReceiver) {
309+
outputReceiver.get(tag1).output(input);
310+
outputReceiver.get(tag2).output(input);
311+
outputReceiver.get(tag3).output(input);
312+
}
313+
}
250314
}

0 commit comments

Comments
 (0)