Skip to content

Commit 9fe791d

Browse files
committed
SnowflakeIO: use BoundedSource to allow parallel reading of partitioned files generated by Snowflake COPY.
1 parent e82fc47 commit 9fe791d

2 files changed

Lines changed: 204 additions & 83 deletions

File tree

sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java

Lines changed: 194 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,30 @@
1717
*/
1818
package org.apache.beam.sdk.io.snowflake;
1919

20-
import static org.apache.beam.sdk.io.TextIO.readFiles;
2120
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
2221

2322
import com.google.auto.value.AutoValue;
2423
import com.opencsv.CSVParser;
2524
import com.opencsv.CSVParserBuilder;
25+
import java.io.BufferedReader;
2626
import java.io.IOException;
27+
import java.io.InputStream;
28+
import java.io.InputStreamReader;
2729
import java.io.Serializable;
30+
import java.nio.channels.Channels;
31+
import java.nio.channels.ReadableByteChannel;
32+
import java.nio.charset.StandardCharsets;
2833
import java.security.PrivateKey;
2934
import java.sql.SQLException;
3035
import java.time.LocalDateTime;
3136
import java.time.ZoneId;
3237
import java.util.ArrayList;
38+
import java.util.Collections;
3339
import java.util.List;
3440
import java.util.UUID;
3541
import java.util.concurrent.ConcurrentHashMap;
3642
import java.util.stream.Collectors;
43+
import java.util.zip.GZIPInputStream;
3744
import javax.annotation.Nullable;
3845
import javax.sql.DataSource;
3946
import net.snowflake.client.api.datasource.SnowflakeDataSource;
@@ -43,11 +50,13 @@
4350
import org.apache.beam.sdk.coders.Coder;
4451
import org.apache.beam.sdk.coders.ListCoder;
4552
import org.apache.beam.sdk.coders.StringUtf8Coder;
53+
import org.apache.beam.sdk.io.BoundedSource;
4654
import org.apache.beam.sdk.io.Compression;
4755
import org.apache.beam.sdk.io.FileIO;
4856
import org.apache.beam.sdk.io.FileSystems;
4957
import org.apache.beam.sdk.io.TextIO;
5058
import org.apache.beam.sdk.io.WriteFilesResult;
59+
import org.apache.beam.sdk.io.fs.MatchResult;
5160
import org.apache.beam.sdk.io.fs.MoveOptions;
5261
import org.apache.beam.sdk.io.fs.ResourceId;
5362
import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema;
@@ -59,6 +68,7 @@
5968
import org.apache.beam.sdk.io.snowflake.services.SnowflakeServices;
6069
import org.apache.beam.sdk.io.snowflake.services.SnowflakeServicesImpl;
6170
import org.apache.beam.sdk.io.snowflake.services.SnowflakeStreamingServiceConfig;
71+
import org.apache.beam.sdk.options.PipelineOptions;
6272
import org.apache.beam.sdk.options.ValueProvider;
6373
import org.apache.beam.sdk.transforms.Combine;
6474
import org.apache.beam.sdk.transforms.Create;
@@ -67,7 +77,6 @@
6777
import org.apache.beam.sdk.transforms.PTransform;
6878
import org.apache.beam.sdk.transforms.ParDo;
6979
import org.apache.beam.sdk.transforms.Reify;
70-
import org.apache.beam.sdk.transforms.Reshuffle;
7180
import org.apache.beam.sdk.transforms.SerializableFunction;
7281
import org.apache.beam.sdk.transforms.SimpleFunction;
7382
import org.apache.beam.sdk.transforms.Values;
@@ -424,31 +433,26 @@ public Read<T> withQuotationMark(ValueProvider<String> quotationMark) {
424433
public PCollection<T> expand(PBegin input) {
425434
checkArguments();
426435

427-
PCollection<Void> emptyCollection = input.apply(Create.of((Void) null));
428436
String tmpDirName = makeTmpDirName();
429-
PCollection<T> output =
430-
emptyCollection
431-
.apply(
432-
ParDo.of(
433-
new CopyIntoStageFn(
434-
getDataSourceProviderFn(),
435-
getQuery(),
436-
getTable(),
437-
getStorageIntegrationName(),
438-
getStagingBucketName(),
439-
tmpDirName,
440-
getSnowflakeServices(),
441-
getQuotationMark())))
442-
.apply(Reshuffle.viaRandomKey())
443-
.apply(FileIO.matchAll())
444-
.apply(FileIO.readMatches())
445-
.apply(readFiles())
446-
.apply(ParDo.of(new MapCsvToStringArrayFn(getQuotationMark())))
447-
.apply(ParDo.of(new MapStringArrayToUserDataFn<>(getCsvMapper())));
448437

438+
SnowflakeBoundedSource<T> source =
439+
new SnowflakeBoundedSource<>(
440+
getDataSourceProviderFn(),
441+
getQuery(),
442+
getTable(),
443+
getStorageIntegrationName(),
444+
getStagingBucketName(),
445+
tmpDirName,
446+
getSnowflakeServices(),
447+
getQuotationMark(),
448+
getCsvMapper(),
449+
getCoder());
450+
451+
PCollection<T> output = input.apply(org.apache.beam.sdk.io.Read.from(source));
449452
output.setCoder(getCoder());
450453

451-
emptyCollection
454+
input
455+
.apply(Create.of((Void) null))
452456
.apply(Wait.on(output))
453457
.apply(ParDo.of(new CleanTmpFilesFromGcsFn(getStagingBucketName(), tmpDirName)));
454458
return output;
@@ -483,103 +487,212 @@ private String makeTmpDirName() {
483487
);
484488
}
485489

486-
private static class CopyIntoStageFn extends DoFn<Object, String> {
490+
/**
491+
* A {@link BoundedSource} that reads from Snowflake by running COPY INTO to stage CSV files,
492+
* then splitting into one sub-source per file.
493+
*/
494+
private static class SnowflakeBoundedSource<T> extends BoundedSource<T> {
495+
private static final Logger LOG = LoggerFactory.getLogger(SnowflakeBoundedSource.class);
496+
487497
private final SerializableFunction<Void, DataSource> dataSourceProviderFn;
488-
private final ValueProvider<String> query;
489-
private final ValueProvider<String> database;
490-
private final ValueProvider<String> schema;
491-
private final ValueProvider<String> table;
498+
private final @Nullable ValueProvider<String> query;
499+
private final @Nullable ValueProvider<String> table;
492500
private final ValueProvider<String> storageIntegrationName;
493-
private final ValueProvider<String> stagingBucketDir;
501+
private final ValueProvider<String> stagingBucketName;
494502
private final String tmpDirName;
495503
private final SnowflakeServices snowflakeServices;
496504
private final ValueProvider<String> quotationMark;
505+
private final CsvMapper<T> csvMapper;
506+
private final Coder<T> coder;
507+
508+
// Non-null only for child sources (one per staged file)
509+
private final @Nullable String filePath;
510+
private final long fileSize;
497511

498-
private CopyIntoStageFn(
512+
SnowflakeBoundedSource(
499513
SerializableFunction<Void, DataSource> dataSourceProviderFn,
500-
ValueProvider<String> query,
501-
ValueProvider<String> table,
514+
@Nullable ValueProvider<String> query,
515+
@Nullable ValueProvider<String> table,
502516
ValueProvider<String> storageIntegrationName,
503-
ValueProvider<String> stagingBucketDir,
517+
ValueProvider<String> stagingBucketName,
504518
String tmpDirName,
505519
SnowflakeServices snowflakeServices,
506-
ValueProvider<String> quotationMark) {
520+
ValueProvider<String> quotationMark,
521+
CsvMapper<T> csvMapper,
522+
Coder<T> coder) {
523+
this(
524+
dataSourceProviderFn,
525+
query,
526+
table,
527+
storageIntegrationName,
528+
stagingBucketName,
529+
tmpDirName,
530+
snowflakeServices,
531+
quotationMark,
532+
csvMapper,
533+
coder,
534+
null,
535+
0);
536+
}
537+
538+
private SnowflakeBoundedSource(
539+
SerializableFunction<Void, DataSource> dataSourceProviderFn,
540+
@Nullable ValueProvider<String> query,
541+
@Nullable ValueProvider<String> table,
542+
ValueProvider<String> storageIntegrationName,
543+
ValueProvider<String> stagingBucketName,
544+
String tmpDirName,
545+
SnowflakeServices snowflakeServices,
546+
ValueProvider<String> quotationMark,
547+
CsvMapper<T> csvMapper,
548+
Coder<T> coder,
549+
@Nullable String filePath,
550+
long fileSize) {
507551
this.dataSourceProviderFn = dataSourceProviderFn;
508552
this.query = query;
509553
this.table = table;
510554
this.storageIntegrationName = storageIntegrationName;
555+
this.stagingBucketName = stagingBucketName;
556+
this.tmpDirName = tmpDirName;
511557
this.snowflakeServices = snowflakeServices;
512558
this.quotationMark = quotationMark;
513-
this.stagingBucketDir = stagingBucketDir;
514-
this.tmpDirName = tmpDirName;
515-
DataSourceProviderFromDataSourceConfiguration
516-
dataSourceProviderFromDataSourceConfiguration =
517-
(DataSourceProviderFromDataSourceConfiguration) this.dataSourceProviderFn;
518-
DataSourceConfiguration config = dataSourceProviderFromDataSourceConfiguration.getConfig();
519-
520-
this.database = config.getDatabase();
521-
this.schema = config.getSchema();
559+
this.csvMapper = csvMapper;
560+
this.coder = coder;
561+
this.filePath = filePath;
562+
this.fileSize = fileSize;
522563
}
523564

524-
@ProcessElement
525-
public void processElement(ProcessContext context) throws Exception {
526-
String databaseValue = getValueOrNull(this.database);
527-
String schemaValue = getValueOrNull(this.schema);
528-
String tableValue = getValueOrNull(this.table);
529-
String queryValue = getValueOrNull(this.query);
565+
@Override
566+
public List<? extends BoundedSource<T>> split(
567+
long desiredBundleSizeBytes, PipelineOptions options) throws Exception {
568+
if (filePath != null) {
569+
return Collections.singletonList(this);
570+
}
530571

531572
String stagingBucketRunDir =
532573
String.format(
533574
"%s/%s/run_%s/",
534-
stagingBucketDir.get(), tmpDirName, UUID.randomUUID().toString().subSequence(0, 8));
575+
stagingBucketName.get(),
576+
tmpDirName,
577+
UUID.randomUUID().toString().subSequence(0, 8));
535578

536-
SnowflakeBatchServiceConfig config =
579+
DataSourceProviderFromDataSourceConfiguration dsProvider =
580+
(DataSourceProviderFromDataSourceConfiguration) dataSourceProviderFn;
581+
DataSourceConfiguration config = dsProvider.getConfig();
582+
583+
SnowflakeBatchServiceConfig batchConfig =
537584
new SnowflakeBatchServiceConfig(
538585
dataSourceProviderFn,
539-
databaseValue,
540-
schemaValue,
541-
tableValue,
542-
queryValue,
586+
getValueOrNull(config.getDatabase()),
587+
getValueOrNull(config.getSchema()),
588+
getValueOrNull(table),
589+
getValueOrNull(query),
543590
storageIntegrationName.get(),
544591
stagingBucketRunDir,
545592
quotationMark.get());
546593

547-
String output = snowflakeServices.getBatchService().read(config);
594+
LOG.info("Running Snowflake COPY INTO stage: {}", stagingBucketRunDir);
595+
String globPattern = snowflakeServices.getBatchService().read(batchConfig);
596+
597+
List<MatchResult.Metadata> files = FileSystems.match(globPattern).metadata();
598+
LOG.info("Snowflake COPY INTO produced {} files", files.size());
599+
600+
return files.stream()
601+
.map(
602+
metadata ->
603+
new SnowflakeBoundedSource<T>(
604+
dataSourceProviderFn,
605+
query,
606+
table,
607+
storageIntegrationName,
608+
stagingBucketName,
609+
tmpDirName,
610+
snowflakeServices,
611+
quotationMark,
612+
csvMapper,
613+
coder,
614+
metadata.resourceId().toString(),
615+
metadata.sizeBytes()))
616+
.collect(Collectors.toList());
617+
}
548618

549-
context.output(output);
619+
@Override
620+
public long getEstimatedSizeBytes(PipelineOptions options) {
621+
return fileSize;
550622
}
551-
}
552623

553-
/**
554-
* Parses {@code String} from incoming data in {@link PCollection} to have proper format for CSV
555-
* files.
556-
*/
557-
public static class MapCsvToStringArrayFn extends DoFn<String, String[]> {
558-
private ValueProvider<String> quoteChar;
624+
@Override
625+
public BoundedReader<T> createReader(PipelineOptions options) throws IOException {
626+
if (filePath == null) {
627+
throw new IOException("Cannot create reader from unsplit parent source");
628+
}
629+
return new SnowflakeFileReader<>(this);
630+
}
559631

560-
public MapCsvToStringArrayFn(ValueProvider<String> quoteChar) {
561-
this.quoteChar = quoteChar;
632+
@Override
633+
public Coder<T> getOutputCoder() {
634+
return coder;
562635
}
563636

564-
@ProcessElement
565-
public void processElement(ProcessContext c) throws IOException {
566-
String csvLine = c.element();
567-
CSVParser parser = new CSVParserBuilder().withQuoteChar(quoteChar.get().charAt(0)).build();
568-
String[] parts = parser.parseLine(csvLine);
569-
c.output(parts);
637+
@Override
638+
public void validate() {
639+
// Validation is done in SnowflakeIO.Read.checkArguments()
570640
}
571-
}
572641

573-
private static class MapStringArrayToUserDataFn<T> extends DoFn<String[], T> {
574-
private final CsvMapper<T> csvMapper;
642+
private static class SnowflakeFileReader<T> extends BoundedReader<T> {
643+
private final SnowflakeBoundedSource<T> source;
644+
private transient BufferedReader reader;
645+
private transient CSVParser csvParser;
646+
private T current;
575647

576-
public MapStringArrayToUserDataFn(CsvMapper<T> csvMapper) {
577-
this.csvMapper = csvMapper;
578-
}
648+
SnowflakeFileReader(SnowflakeBoundedSource<T> source) {
649+
this.source = source;
650+
}
579651

580-
@ProcessElement
581-
public void processElement(ProcessContext context) throws Exception {
582-
context.output(csvMapper.mapRow(context.element()));
652+
@Override
653+
public boolean start() throws IOException {
654+
ResourceId resourceId = FileSystems.matchNewResource(source.filePath, false);
655+
ReadableByteChannel channel = FileSystems.open(resourceId);
656+
InputStream inputStream = new GZIPInputStream(Channels.newInputStream(channel));
657+
658+
reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
659+
csvParser =
660+
new CSVParserBuilder().withQuoteChar(source.quotationMark.get().charAt(0)).build();
661+
662+
return advance();
663+
}
664+
665+
@Override
666+
public boolean advance() throws IOException {
667+
String line = reader.readLine();
668+
if (line == null) {
669+
return false;
670+
}
671+
try {
672+
String[] parts = csvParser.parseLine(line);
673+
current = source.csvMapper.mapRow(parts);
674+
return true;
675+
} catch (Exception e) {
676+
throw new IOException("Error mapping CSV row: " + line, e);
677+
}
678+
}
679+
680+
@Override
681+
public T getCurrent() {
682+
return current;
683+
}
684+
685+
@Override
686+
public void close() throws IOException {
687+
if (reader != null) {
688+
reader.close();
689+
}
690+
}
691+
692+
@Override
693+
public BoundedSource<T> getCurrentSource() {
694+
return source;
695+
}
583696
}
584697
}
585698

0 commit comments

Comments
 (0)