Skip to content

Commit e2be9fd

Browse files
authored
Plumb custom batch parameters for autosharding from WriteFiles to FileIO. (#37463)
* Address circular dependencies. * Fix formatting. * Fix tests. * Fix lint. * Remove unused import. * Resolve circular dependency without removing __repr__. * Fix formatting. * Remove nextmark_json_util and move all its methods into nextmark_model. * Restore millis_to_timestamp. * Plumb custom batch params and add tests. * Fix formatting and imports. * Fix imports and test. * Add missing import. * Add another test case for byte count. * Added checks for positive values.
1 parent 15ba487 commit e2be9fd

2 files changed

Lines changed: 205 additions & 0 deletions

File tree

sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,12 @@ public static FileNaming relativeFileNaming(
10591059

10601060
abstract @Nullable Integer getMaxNumWritersPerBundle();
10611061

1062+
/**
 * Record-count batch size stored by {@code withBatchSize}. May be {@code null}; {@code expand}
 * only forwards it to {@code WriteFiles#withBatchSize} when it is non-null.
 */
abstract @Nullable Integer getBatchSize();
1063+
1064+
/**
 * Byte-count batch size stored by {@code withBatchSizeBytes}. May be {@code null}; {@code
 * expand} only forwards it to {@code WriteFiles#withBatchSizeBytes} when it is non-null.
 */
abstract @Nullable Integer getBatchSizeBytes();
1065+
1066+
/**
 * Maximum buffering duration stored by {@code withBatchMaxBufferingDuration}. May be {@code
 * null}; {@code expand} only forwards it to {@code WriteFiles#withBatchMaxBufferingDuration}
 * when it is non-null.
 */
abstract @Nullable Duration getBatchMaxBufferingDuration();
1067+
10621068
abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
10631069

10641070
abstract Builder<DestinationT, UserT> toBuilder();
@@ -1112,6 +1118,13 @@ abstract Builder<DestinationT, UserT> setSharding(
11121118
abstract Builder<DestinationT, UserT> setMaxNumWritersPerBundle(
11131119
@Nullable Integer maxNumWritersPerBundle);
11141120

1121+
/** Builder setter for the value returned by {@code getBatchSize()}; {@code null} is accepted. */
abstract Builder<DestinationT, UserT> setBatchSize(@Nullable Integer batchSize);
1122+
1123+
/**
 * Builder setter for the value returned by {@code getBatchSizeBytes()}; {@code null} is
 * accepted.
 */
abstract Builder<DestinationT, UserT> setBatchSizeBytes(@Nullable Integer batchSizeBytes);
1124+
1125+
/**
 * Builder setter for the value returned by {@code getBatchMaxBufferingDuration()}; {@code null}
 * is accepted.
 */
abstract Builder<DestinationT, UserT> setBatchMaxBufferingDuration(
1126+
@Nullable Duration batchMaxBufferingDuration);
1127+
11151128
abstract Builder<DestinationT, UserT> setBadRecordErrorHandler(
11161129
@Nullable ErrorHandler<BadRecord, ?> badRecordErrorHandler);
11171130

@@ -1301,6 +1314,7 @@ public Write<DestinationT, UserT> withDestinationCoder(Coder<DestinationT> desti
13011314
*/
13021315
public Write<DestinationT, UserT> withNumShards(int numShards) {
13031316
checkArgument(numShards >= 0, "numShards must be non-negative, but was: %s", numShards);
1317+
checkArgument(!getAutoSharding(), "Cannot set numShards when withAutoSharding() is used");
13041318
if (numShards == 0) {
13051319
return withNumShards(null);
13061320
}
@@ -1311,6 +1325,7 @@ public Write<DestinationT, UserT> withNumShards(int numShards) {
13111325
* Like {@link #withNumShards(int)}. Specifying {@code null} means runner-determined sharding.
13121326
*/
13131327
public Write<DestinationT, UserT> withNumShards(@Nullable ValueProvider<Integer> numShards) {
1328+
checkArgument(!getAutoSharding(), "Cannot set numShards when withAutoSharding() is used");
13141329
return toBuilder().setNumShards(numShards).build();
13151330
}
13161331

@@ -1321,6 +1336,7 @@ public Write<DestinationT, UserT> withNumShards(@Nullable ValueProvider<Integer>
13211336
public Write<DestinationT, UserT> withSharding(
13221337
PTransform<PCollection<UserT>, PCollectionView<Integer>> sharding) {
13231338
checkArgument(sharding != null, "sharding can not be null");
1339+
checkArgument(!getAutoSharding(), "Cannot set sharding when withAutoSharding() is used");
13241340
return toBuilder().setSharding(sharding).build();
13251341
}
13261342

@@ -1337,6 +1353,9 @@ public Write<DestinationT, UserT> withIgnoreWindowing() {
13371353
}
13381354

13391355
public Write<DestinationT, UserT> withAutoSharding() {
1356+
checkArgument(
1357+
getNumShards() == null && getSharding() == null,
1358+
"Cannot use withAutoSharding() when withNumShards() or withSharding() is set");
13401359
return toBuilder().setAutoSharding(true).build();
13411360
}
13421361

@@ -1366,6 +1385,44 @@ public Write<DestinationT, UserT> withBadRecordErrorHandler(
13661385
return toBuilder().setBadRecordErrorHandler(errorHandler).build();
13671386
}
13681387

1388+
/**
1389+
* Returns a new {@link Write} that will batch the input records using specified batch size. The
1390+
* default value is {@link WriteFiles#FILE_TRIGGERING_RECORD_COUNT}.
1391+
*
1392+
* <p>This option is used only for writing unbounded data with auto-sharding.
1393+
*/
1394+
public Write<DestinationT, UserT> withBatchSize(@Nullable Integer batchSize) {
1395+
checkArgument(batchSize > 0, "batchSize must be positive, but was: %s", batchSize);
1396+
return toBuilder().setBatchSize(batchSize).build();
1397+
}
1398+
1399+
/**
1400+
* Returns a new {@link Write} that will batch the input records using specified batch size in
1401+
* bytes. The default value is {@link WriteFiles#FILE_TRIGGERING_BYTE_COUNT}.
1402+
*
1403+
* <p>This option is used only for writing unbounded data with auto-sharding.
1404+
*/
1405+
public Write<DestinationT, UserT> withBatchSizeBytes(@Nullable Integer batchSizeBytes) {
1406+
checkArgument(
1407+
batchSizeBytes > 0, "batchSizeBytes must be positive, but was: %s", batchSizeBytes);
1408+
return toBuilder().setBatchSizeBytes(batchSizeBytes).build();
1409+
}
1410+
1411+
/**
1412+
* Returns a new {@link Write} that will batch the input records using specified max buffering
1413+
* duration. The default value is {@link WriteFiles#FILE_TRIGGERING_RECORD_BUFFERING_DURATION}.
1414+
*
1415+
* <p>This option is used only for writing unbounded data with auto-sharding.
1416+
*/
1417+
public Write<DestinationT, UserT> withBatchMaxBufferingDuration(
1418+
@Nullable Duration batchMaxBufferingDuration) {
1419+
checkArgument(
1420+
batchMaxBufferingDuration.isLongerThan(Duration.ZERO),
1421+
"batchMaxBufferingDuration must be positive, but was: %s",
1422+
batchMaxBufferingDuration);
1423+
return toBuilder().setBatchMaxBufferingDuration(batchMaxBufferingDuration).build();
1424+
}
1425+
13691426
@VisibleForTesting
13701427
Contextful<Fn<DestinationT, FileNaming>> resolveFileNamingFn() {
13711428
if (getDynamic()) {
@@ -1482,6 +1539,15 @@ public WriteFilesResult<DestinationT> expand(PCollection<UserT> input) {
14821539
if (getBadRecordErrorHandler() != null) {
14831540
writeFiles = writeFiles.withBadRecordErrorHandler(getBadRecordErrorHandler());
14841541
}
1542+
if (getBatchSize() != null) {
1543+
writeFiles = writeFiles.withBatchSize(getBatchSize());
1544+
}
1545+
if (getBatchSizeBytes() != null) {
1546+
writeFiles = writeFiles.withBatchSizeBytes(getBatchSizeBytes());
1547+
}
1548+
if (getBatchMaxBufferingDuration() != null) {
1549+
writeFiles = writeFiles.withBatchMaxBufferingDuration(getBatchMaxBufferingDuration());
1550+
}
14851551
return input.apply(writeFiles);
14861552
}
14871553

sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,14 @@
1919

2020
import static org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions.RESOLVE_FILE;
2121
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull;
22+
import static org.hamcrest.MatcherAssert.assertThat;
23+
import static org.hamcrest.Matchers.containsInAnyOrder;
2224
import static org.hamcrest.Matchers.isA;
2325
import static org.junit.Assert.assertEquals;
2426
import static org.junit.Assert.assertFalse;
2527
import static org.junit.Assert.assertTrue;
2628

29+
import java.io.BufferedReader;
2730
import java.io.File;
2831
import java.io.FileNotFoundException;
2932
import java.io.FileOutputStream;
@@ -38,38 +41,48 @@
3841
import java.nio.file.Paths;
3942
import java.nio.file.StandardCopyOption;
4043
import java.nio.file.attribute.FileTime;
44+
import java.util.ArrayList;
4145
import java.util.Arrays;
46+
import java.util.Collections;
4247
import java.util.List;
4348
import java.util.Objects;
4449
import java.util.zip.GZIPOutputStream;
4550
import org.apache.beam.sdk.coders.StringUtf8Coder;
4651
import org.apache.beam.sdk.coders.VarIntCoder;
4752
import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
4853
import org.apache.beam.sdk.io.fs.MatchResult;
54+
import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
4955
import org.apache.beam.sdk.options.PipelineOptionsFactory;
5056
import org.apache.beam.sdk.state.StateSpec;
5157
import org.apache.beam.sdk.state.StateSpecs;
5258
import org.apache.beam.sdk.state.ValueState;
5359
import org.apache.beam.sdk.testing.NeedsRunner;
5460
import org.apache.beam.sdk.testing.PAssert;
5561
import org.apache.beam.sdk.testing.TestPipeline;
62+
import org.apache.beam.sdk.testing.UsesUnboundedPCollections;
5663
import org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo;
5764
import org.apache.beam.sdk.transforms.Contextful;
5865
import org.apache.beam.sdk.transforms.Create;
5966
import org.apache.beam.sdk.transforms.DoFn;
6067
import org.apache.beam.sdk.transforms.MapElements;
68+
import org.apache.beam.sdk.transforms.PTransform;
6169
import org.apache.beam.sdk.transforms.ParDo;
6270
import org.apache.beam.sdk.transforms.Requirements;
6371
import org.apache.beam.sdk.transforms.SerializableFunctions;
6472
import org.apache.beam.sdk.transforms.View;
6573
import org.apache.beam.sdk.transforms.Watch;
74+
import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
75+
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
6676
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
6777
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
78+
import org.apache.beam.sdk.transforms.windowing.Window;
6879
import org.apache.beam.sdk.values.KV;
6980
import org.apache.beam.sdk.values.PCollection;
81+
import org.apache.beam.sdk.values.PCollection.IsBounded;
7082
import org.apache.beam.sdk.values.PCollectionView;
7183
import org.apache.beam.sdk.values.TypeDescriptor;
7284
import org.apache.beam.sdk.values.TypeDescriptors;
85+
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
7386
import org.joda.time.Duration;
7487
import org.junit.Rule;
7588
import org.junit.Test;
@@ -547,4 +560,130 @@ public void testFileIoDynamicNaming() throws IOException {
547560
"Output file shard 0 exists after pipeline completes",
548561
new File(outputFileName + "-0").exists());
549562
}
563+
564+
@Test
565+
@Category({NeedsRunner.class, UsesUnboundedPCollections.class})
566+
public void testWriteUnboundedWithCustomBatchSize() throws IOException {
567+
File root = tmpFolder.getRoot();
568+
List<String> inputs = Arrays.asList("one", "two", "three", "four", "five", "six");
569+
570+
PTransform<PCollection<String>, PCollection<String>> transform =
571+
Window.<String>into(FixedWindows.of(Duration.standardSeconds(10)))
572+
.triggering(AfterWatermark.pastEndOfWindow())
573+
.withAllowedLateness(Duration.ZERO)
574+
.discardingFiredPanes();
575+
576+
FileIO.Write<Void, String> write =
577+
FileIO.<String>write()
578+
.via(TextIO.sink())
579+
.to(root.getAbsolutePath())
580+
.withPrefix("output")
581+
.withSuffix(".txt")
582+
.withAutoSharding()
583+
.withBatchSize(3)
584+
.withBatchSizeBytes(1024 * 1024) // Set high to avoid triggering flushing by byte count.
585+
.withBatchMaxBufferingDuration(
586+
Duration.standardMinutes(1)); // Set high to avoid triggering flushing by duration.
587+
588+
// Prepare timestamps for the elements.
589+
List<Long> timestamps = new ArrayList<>();
590+
for (long i = 0; i < inputs.size(); i++) {
591+
timestamps.add(i + 1);
592+
}
593+
594+
p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of()))
595+
.setIsBoundedInternal(IsBounded.UNBOUNDED)
596+
.apply(transform)
597+
.apply(write);
598+
p.run().waitUntilFinish();
599+
600+
// Verify that the custom batch parameters are set.
601+
assertEquals(3, write.getBatchSize().intValue());
602+
assertEquals(1024 * 1024, write.getBatchSizeBytes().intValue());
603+
assertEquals(Duration.standardMinutes(1), write.getBatchMaxBufferingDuration());
604+
605+
// Verify file contents.
606+
checkFileContents(root, "output", inputs);
607+
608+
// With auto-sharding, we can't assert on the exact number of output files, but because
609+
// batch size is 3 and there are 6 elements, we expect at least 2 files.
610+
final String pattern = new File(root, "output").getAbsolutePath() + "*";
611+
List<Metadata> metadata =
612+
FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
613+
assertTrue(metadata.size() >= 2);
614+
}
615+
616+
@Test
617+
@Category({NeedsRunner.class, UsesUnboundedPCollections.class})
618+
public void testWriteUnboundedWithCustomBatchSizeBytes() throws IOException {
619+
File root = tmpFolder.getRoot();
620+
// The elements plus newline characters give a total of 4+4+6+5+5+4=28 bytes.
621+
List<String> inputs = Arrays.asList("one", "two", "three", "four", "five", "six");
622+
// Assign timestamps so that all elements fall into the same 10s window.
623+
List<Long> timestamps = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L);
624+
625+
FileIO.Write<Void, String> write =
626+
FileIO.<String>write()
627+
.via(TextIO.sink())
628+
.to(root.getAbsolutePath())
629+
.withPrefix("output")
630+
.withSuffix(".txt")
631+
.withAutoSharding()
632+
.withBatchSize(1000) // Set high to avoid flushing by record count.
633+
.withBatchSizeBytes(10)
634+
.withBatchMaxBufferingDuration(
635+
Duration.standardMinutes(1)); // Set high to avoid flushing by duration.
636+
637+
p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of()))
638+
.setIsBoundedInternal(IsBounded.UNBOUNDED)
639+
.apply(
640+
Window.<String>into(FixedWindows.of(Duration.standardSeconds(10)))
641+
.triggering(AfterWatermark.pastEndOfWindow())
642+
.withAllowedLateness(Duration.ZERO)
643+
.discardingFiredPanes())
644+
.apply(write);
645+
646+
p.run().waitUntilFinish();
647+
648+
// Verify that the custom batch parameters are set.
649+
assertEquals(1000, write.getBatchSize().intValue());
650+
assertEquals(10, write.getBatchSizeBytes().intValue());
651+
assertEquals(Duration.standardMinutes(1), write.getBatchMaxBufferingDuration());
652+
checkFileContents(root, "output", inputs);
653+
654+
// With auto-sharding, we cannot assert on the exact number of output files. The BatchSizeBytes
655+
// acts as a threshold for flushing; once buffer size reaches 10 bytes, a flush is triggered,
656+
// but more items may be added before it completes. With 28 bytes total, we can only guarantee
657+
// at least 2 files are produced.
658+
final String pattern = new File(root, "output").getAbsolutePath() + "*";
659+
List<Metadata> metadata =
660+
FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
661+
assertTrue(metadata.size() >= 2);
662+
}
663+
664+
static void checkFileContents(File rootDir, String prefix, List<String> inputs)
665+
throws IOException {
666+
List<File> outputFiles = Lists.newArrayList();
667+
final String pattern = new File(rootDir, prefix).getAbsolutePath() + "*";
668+
List<Metadata> metadata =
669+
FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
670+
for (Metadata meta : metadata) {
671+
outputFiles.add(new File(meta.resourceId().toString()));
672+
}
673+
assertFalse("Should have produced at least 1 output file", outputFiles.isEmpty());
674+
675+
List<String> actual = Lists.newArrayList();
676+
for (File outputFile : outputFiles) {
677+
List<String> actualShard = Lists.newArrayList();
678+
try (BufferedReader reader =
679+
Files.newBufferedReader(outputFile.toPath(), StandardCharsets.UTF_8)) {
680+
String line;
681+
while ((line = reader.readLine()) != null) {
682+
actualShard.add(line);
683+
}
684+
}
685+
actual.addAll(actualShard);
686+
}
687+
assertThat(actual, containsInAnyOrder(inputs.toArray()));
688+
}
550689
}

0 commit comments

Comments
 (0)