|
17 | 17 | */ |
18 | 18 | package org.apache.beam.sdk.io.snowflake; |
19 | 19 |
|
20 | | -import static org.apache.beam.sdk.io.TextIO.readFiles; |
21 | 20 | import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; |
22 | 21 |
|
23 | 22 | import com.google.auto.value.AutoValue; |
24 | 23 | import com.opencsv.CSVParser; |
25 | 24 | import com.opencsv.CSVParserBuilder; |
| 25 | +import java.io.BufferedReader; |
26 | 26 | import java.io.IOException; |
| 27 | +import java.io.InputStream; |
| 28 | +import java.io.InputStreamReader; |
27 | 29 | import java.io.Serializable; |
| 30 | +import java.nio.channels.Channels; |
| 31 | +import java.nio.channels.ReadableByteChannel; |
| 32 | +import java.nio.charset.StandardCharsets; |
28 | 33 | import java.security.PrivateKey; |
29 | 34 | import java.sql.SQLException; |
30 | 35 | import java.time.LocalDateTime; |
31 | 36 | import java.time.ZoneId; |
32 | 37 | import java.util.ArrayList; |
| 38 | +import java.util.Collections; |
33 | 39 | import java.util.List; |
34 | 40 | import java.util.UUID; |
35 | 41 | import java.util.concurrent.ConcurrentHashMap; |
36 | 42 | import java.util.stream.Collectors; |
| 43 | +import java.util.zip.GZIPInputStream; |
37 | 44 | import javax.annotation.Nullable; |
38 | 45 | import javax.sql.DataSource; |
39 | 46 | import net.snowflake.client.api.datasource.SnowflakeDataSource; |
|
43 | 50 | import org.apache.beam.sdk.coders.Coder; |
44 | 51 | import org.apache.beam.sdk.coders.ListCoder; |
45 | 52 | import org.apache.beam.sdk.coders.StringUtf8Coder; |
| 53 | +import org.apache.beam.sdk.io.BoundedSource; |
46 | 54 | import org.apache.beam.sdk.io.Compression; |
47 | 55 | import org.apache.beam.sdk.io.FileIO; |
48 | 56 | import org.apache.beam.sdk.io.FileSystems; |
49 | 57 | import org.apache.beam.sdk.io.TextIO; |
50 | 58 | import org.apache.beam.sdk.io.WriteFilesResult; |
| 59 | +import org.apache.beam.sdk.io.fs.MatchResult; |
51 | 60 | import org.apache.beam.sdk.io.fs.MoveOptions; |
52 | 61 | import org.apache.beam.sdk.io.fs.ResourceId; |
53 | 62 | import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema; |
|
59 | 68 | import org.apache.beam.sdk.io.snowflake.services.SnowflakeServices; |
60 | 69 | import org.apache.beam.sdk.io.snowflake.services.SnowflakeServicesImpl; |
61 | 70 | import org.apache.beam.sdk.io.snowflake.services.SnowflakeStreamingServiceConfig; |
| 71 | +import org.apache.beam.sdk.options.PipelineOptions; |
62 | 72 | import org.apache.beam.sdk.options.ValueProvider; |
63 | 73 | import org.apache.beam.sdk.transforms.Combine; |
64 | 74 | import org.apache.beam.sdk.transforms.Create; |
|
67 | 77 | import org.apache.beam.sdk.transforms.PTransform; |
68 | 78 | import org.apache.beam.sdk.transforms.ParDo; |
69 | 79 | import org.apache.beam.sdk.transforms.Reify; |
70 | | -import org.apache.beam.sdk.transforms.Reshuffle; |
71 | 80 | import org.apache.beam.sdk.transforms.SerializableFunction; |
72 | 81 | import org.apache.beam.sdk.transforms.SimpleFunction; |
73 | 82 | import org.apache.beam.sdk.transforms.Values; |
@@ -424,31 +433,26 @@ public Read<T> withQuotationMark(ValueProvider<String> quotationMark) { |
424 | 433 | public PCollection<T> expand(PBegin input) { |
425 | 434 | checkArguments(); |
426 | 435 |
|
427 | | - PCollection<Void> emptyCollection = input.apply(Create.of((Void) null)); |
428 | 436 | String tmpDirName = makeTmpDirName(); |
429 | | - PCollection<T> output = |
430 | | - emptyCollection |
431 | | - .apply( |
432 | | - ParDo.of( |
433 | | - new CopyIntoStageFn( |
434 | | - getDataSourceProviderFn(), |
435 | | - getQuery(), |
436 | | - getTable(), |
437 | | - getStorageIntegrationName(), |
438 | | - getStagingBucketName(), |
439 | | - tmpDirName, |
440 | | - getSnowflakeServices(), |
441 | | - getQuotationMark()))) |
442 | | - .apply(Reshuffle.viaRandomKey()) |
443 | | - .apply(FileIO.matchAll()) |
444 | | - .apply(FileIO.readMatches()) |
445 | | - .apply(readFiles()) |
446 | | - .apply(ParDo.of(new MapCsvToStringArrayFn(getQuotationMark()))) |
447 | | - .apply(ParDo.of(new MapStringArrayToUserDataFn<>(getCsvMapper()))); |
448 | 437 |
|
| 438 | + SnowflakeBoundedSource<T> source = |
| 439 | + new SnowflakeBoundedSource<>( |
| 440 | + getDataSourceProviderFn(), |
| 441 | + getQuery(), |
| 442 | + getTable(), |
| 443 | + getStorageIntegrationName(), |
| 444 | + getStagingBucketName(), |
| 445 | + tmpDirName, |
| 446 | + getSnowflakeServices(), |
| 447 | + getQuotationMark(), |
| 448 | + getCsvMapper(), |
| 449 | + getCoder()); |
| 450 | + |
| 451 | + PCollection<T> output = input.apply(org.apache.beam.sdk.io.Read.from(source)); |
449 | 452 | output.setCoder(getCoder()); |
450 | 453 |
|
451 | | - emptyCollection |
| 454 | + input |
| 455 | + .apply(Create.of((Void) null)) |
452 | 456 | .apply(Wait.on(output)) |
453 | 457 | .apply(ParDo.of(new CleanTmpFilesFromGcsFn(getStagingBucketName(), tmpDirName))); |
454 | 458 | return output; |
@@ -483,103 +487,212 @@ private String makeTmpDirName() { |
483 | 487 | ); |
484 | 488 | } |
485 | 489 |
|
486 | | - private static class CopyIntoStageFn extends DoFn<Object, String> { |
| 490 | + /** |
| 491 | + * A {@link BoundedSource} that reads from Snowflake by running COPY INTO to stage CSV files, |
| 492 | + * then splitting into one sub-source per file. |
| 493 | + */ |
| 494 | + private static class SnowflakeBoundedSource<T> extends BoundedSource<T> { |
| 495 | + private static final Logger LOG = LoggerFactory.getLogger(SnowflakeBoundedSource.class); |
| 496 | + |
487 | 497 | private final SerializableFunction<Void, DataSource> dataSourceProviderFn; |
488 | | - private final ValueProvider<String> query; |
489 | | - private final ValueProvider<String> database; |
490 | | - private final ValueProvider<String> schema; |
491 | | - private final ValueProvider<String> table; |
| 498 | + private final @Nullable ValueProvider<String> query; |
| 499 | + private final @Nullable ValueProvider<String> table; |
492 | 500 | private final ValueProvider<String> storageIntegrationName; |
493 | | - private final ValueProvider<String> stagingBucketDir; |
| 501 | + private final ValueProvider<String> stagingBucketName; |
494 | 502 | private final String tmpDirName; |
495 | 503 | private final SnowflakeServices snowflakeServices; |
496 | 504 | private final ValueProvider<String> quotationMark; |
| 505 | + private final CsvMapper<T> csvMapper; |
| 506 | + private final Coder<T> coder; |
| 507 | + |
| 508 | + // Non-null only for child sources (one per staged file) |
| 509 | + private final @Nullable String filePath; |
| 510 | + private final long fileSize; |
497 | 511 |
|
498 | | - private CopyIntoStageFn( |
| 512 | + SnowflakeBoundedSource( |
499 | 513 | SerializableFunction<Void, DataSource> dataSourceProviderFn, |
500 | | - ValueProvider<String> query, |
501 | | - ValueProvider<String> table, |
| 514 | + @Nullable ValueProvider<String> query, |
| 515 | + @Nullable ValueProvider<String> table, |
502 | 516 | ValueProvider<String> storageIntegrationName, |
503 | | - ValueProvider<String> stagingBucketDir, |
| 517 | + ValueProvider<String> stagingBucketName, |
504 | 518 | String tmpDirName, |
505 | 519 | SnowflakeServices snowflakeServices, |
506 | | - ValueProvider<String> quotationMark) { |
| 520 | + ValueProvider<String> quotationMark, |
| 521 | + CsvMapper<T> csvMapper, |
| 522 | + Coder<T> coder) { |
| 523 | + this( |
| 524 | + dataSourceProviderFn, |
| 525 | + query, |
| 526 | + table, |
| 527 | + storageIntegrationName, |
| 528 | + stagingBucketName, |
| 529 | + tmpDirName, |
| 530 | + snowflakeServices, |
| 531 | + quotationMark, |
| 532 | + csvMapper, |
| 533 | + coder, |
| 534 | + null, |
| 535 | + 0); |
| 536 | + } |
| 537 | + |
| 538 | + private SnowflakeBoundedSource( |
| 539 | + SerializableFunction<Void, DataSource> dataSourceProviderFn, |
| 540 | + @Nullable ValueProvider<String> query, |
| 541 | + @Nullable ValueProvider<String> table, |
| 542 | + ValueProvider<String> storageIntegrationName, |
| 543 | + ValueProvider<String> stagingBucketName, |
| 544 | + String tmpDirName, |
| 545 | + SnowflakeServices snowflakeServices, |
| 546 | + ValueProvider<String> quotationMark, |
| 547 | + CsvMapper<T> csvMapper, |
| 548 | + Coder<T> coder, |
| 549 | + @Nullable String filePath, |
| 550 | + long fileSize) { |
507 | 551 | this.dataSourceProviderFn = dataSourceProviderFn; |
508 | 552 | this.query = query; |
509 | 553 | this.table = table; |
510 | 554 | this.storageIntegrationName = storageIntegrationName; |
| 555 | + this.stagingBucketName = stagingBucketName; |
| 556 | + this.tmpDirName = tmpDirName; |
511 | 557 | this.snowflakeServices = snowflakeServices; |
512 | 558 | this.quotationMark = quotationMark; |
513 | | - this.stagingBucketDir = stagingBucketDir; |
514 | | - this.tmpDirName = tmpDirName; |
515 | | - DataSourceProviderFromDataSourceConfiguration |
516 | | - dataSourceProviderFromDataSourceConfiguration = |
517 | | - (DataSourceProviderFromDataSourceConfiguration) this.dataSourceProviderFn; |
518 | | - DataSourceConfiguration config = dataSourceProviderFromDataSourceConfiguration.getConfig(); |
519 | | - |
520 | | - this.database = config.getDatabase(); |
521 | | - this.schema = config.getSchema(); |
| 559 | + this.csvMapper = csvMapper; |
| 560 | + this.coder = coder; |
| 561 | + this.filePath = filePath; |
| 562 | + this.fileSize = fileSize; |
522 | 563 | } |
523 | 564 |
|
524 | | - @ProcessElement |
525 | | - public void processElement(ProcessContext context) throws Exception { |
526 | | - String databaseValue = getValueOrNull(this.database); |
527 | | - String schemaValue = getValueOrNull(this.schema); |
528 | | - String tableValue = getValueOrNull(this.table); |
529 | | - String queryValue = getValueOrNull(this.query); |
| 565 | + @Override |
| 566 | + public List<? extends BoundedSource<T>> split( |
| 567 | + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { |
| 568 | + if (filePath != null) { |
| 569 | + return Collections.singletonList(this); |
| 570 | + } |
530 | 571 |
|
531 | 572 | String stagingBucketRunDir = |
532 | 573 | String.format( |
533 | 574 | "%s/%s/run_%s/", |
534 | | - stagingBucketDir.get(), tmpDirName, UUID.randomUUID().toString().subSequence(0, 8)); |
| 575 | + stagingBucketName.get(), |
| 576 | + tmpDirName, |
| 577 | + UUID.randomUUID().toString().subSequence(0, 8)); |
535 | 578 |
|
536 | | - SnowflakeBatchServiceConfig config = |
| 579 | + DataSourceProviderFromDataSourceConfiguration dsProvider = |
| 580 | + (DataSourceProviderFromDataSourceConfiguration) dataSourceProviderFn; |
| 581 | + DataSourceConfiguration config = dsProvider.getConfig(); |
| 582 | + |
| 583 | + SnowflakeBatchServiceConfig batchConfig = |
537 | 584 | new SnowflakeBatchServiceConfig( |
538 | 585 | dataSourceProviderFn, |
539 | | - databaseValue, |
540 | | - schemaValue, |
541 | | - tableValue, |
542 | | - queryValue, |
| 586 | + getValueOrNull(config.getDatabase()), |
| 587 | + getValueOrNull(config.getSchema()), |
| 588 | + getValueOrNull(table), |
| 589 | + getValueOrNull(query), |
543 | 590 | storageIntegrationName.get(), |
544 | 591 | stagingBucketRunDir, |
545 | 592 | quotationMark.get()); |
546 | 593 |
|
547 | | - String output = snowflakeServices.getBatchService().read(config); |
| 594 | + LOG.info("Running Snowflake COPY INTO stage: {}", stagingBucketRunDir); |
| 595 | + String globPattern = snowflakeServices.getBatchService().read(batchConfig); |
| 596 | + |
| 597 | + List<MatchResult.Metadata> files = FileSystems.match(globPattern).metadata(); |
| 598 | + LOG.info("Snowflake COPY INTO produced {} files", files.size()); |
| 599 | + |
| 600 | + return files.stream() |
| 601 | + .map( |
| 602 | + metadata -> |
| 603 | + new SnowflakeBoundedSource<T>( |
| 604 | + dataSourceProviderFn, |
| 605 | + query, |
| 606 | + table, |
| 607 | + storageIntegrationName, |
| 608 | + stagingBucketName, |
| 609 | + tmpDirName, |
| 610 | + snowflakeServices, |
| 611 | + quotationMark, |
| 612 | + csvMapper, |
| 613 | + coder, |
| 614 | + metadata.resourceId().toString(), |
| 615 | + metadata.sizeBytes())) |
| 616 | + .collect(Collectors.toList()); |
| 617 | + } |
548 | 618 |
|
549 | | - context.output(output); |
| 619 | + @Override |
| 620 | + public long getEstimatedSizeBytes(PipelineOptions options) { |
| 621 | + return fileSize; |
550 | 622 | } |
551 | | - } |
552 | 623 |
|
553 | | - /** |
554 | | - * Parses {@code String} from incoming data in {@link PCollection} to have proper format for CSV |
555 | | - * files. |
556 | | - */ |
557 | | - public static class MapCsvToStringArrayFn extends DoFn<String, String[]> { |
558 | | - private ValueProvider<String> quoteChar; |
| 624 | + @Override |
| 625 | + public BoundedReader<T> createReader(PipelineOptions options) throws IOException { |
| 626 | + if (filePath == null) { |
| 627 | + throw new IOException("Cannot create reader from unsplit parent source"); |
| 628 | + } |
| 629 | + return new SnowflakeFileReader<>(this); |
| 630 | + } |
559 | 631 |
|
560 | | - public MapCsvToStringArrayFn(ValueProvider<String> quoteChar) { |
561 | | - this.quoteChar = quoteChar; |
| 632 | + @Override |
| 633 | + public Coder<T> getOutputCoder() { |
| 634 | + return coder; |
562 | 635 | } |
563 | 636 |
|
564 | | - @ProcessElement |
565 | | - public void processElement(ProcessContext c) throws IOException { |
566 | | - String csvLine = c.element(); |
567 | | - CSVParser parser = new CSVParserBuilder().withQuoteChar(quoteChar.get().charAt(0)).build(); |
568 | | - String[] parts = parser.parseLine(csvLine); |
569 | | - c.output(parts); |
| 637 | + @Override |
| 638 | + public void validate() { |
| 639 | + // Validation is done in SnowflakeIO.Read.checkArguments() |
570 | 640 | } |
571 | | - } |
572 | 641 |
|
573 | | - private static class MapStringArrayToUserDataFn<T> extends DoFn<String[], T> { |
574 | | - private final CsvMapper<T> csvMapper; |
| 642 | + private static class SnowflakeFileReader<T> extends BoundedReader<T> { |
| 643 | + private final SnowflakeBoundedSource<T> source; |
| 644 | + private transient BufferedReader reader; |
| 645 | + private transient CSVParser csvParser; |
| 646 | + private T current; |
575 | 647 |
|
576 | | - public MapStringArrayToUserDataFn(CsvMapper<T> csvMapper) { |
577 | | - this.csvMapper = csvMapper; |
578 | | - } |
| 648 | + SnowflakeFileReader(SnowflakeBoundedSource<T> source) { |
| 649 | + this.source = source; |
| 650 | + } |
579 | 651 |
|
580 | | - @ProcessElement |
581 | | - public void processElement(ProcessContext context) throws Exception { |
582 | | - context.output(csvMapper.mapRow(context.element())); |
| 652 | + @Override |
| 653 | + public boolean start() throws IOException { |
| 654 | + ResourceId resourceId = FileSystems.matchNewResource(source.filePath, false); |
| 655 | + ReadableByteChannel channel = FileSystems.open(resourceId); |
| 656 | + InputStream inputStream = new GZIPInputStream(Channels.newInputStream(channel)); |
| 657 | + |
| 658 | + reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)); |
| 659 | + csvParser = |
| 660 | + new CSVParserBuilder().withQuoteChar(source.quotationMark.get().charAt(0)).build(); |
| 661 | + |
| 662 | + return advance(); |
| 663 | + } |
| 664 | + |
| 665 | + @Override |
| 666 | + public boolean advance() throws IOException { |
| 667 | + String line = reader.readLine(); |
| 668 | + if (line == null) { |
| 669 | + return false; |
| 670 | + } |
| 671 | + try { |
| 672 | + String[] parts = csvParser.parseLine(line); |
| 673 | + current = source.csvMapper.mapRow(parts); |
| 674 | + return true; |
| 675 | + } catch (Exception e) { |
| 676 | + throw new IOException("Error mapping CSV row: " + line, e); |
| 677 | + } |
| 678 | + } |
| 679 | + |
| 680 | + @Override |
| 681 | + public T getCurrent() { |
| 682 | + return current; |
| 683 | + } |
| 684 | + |
| 685 | + @Override |
| 686 | + public void close() throws IOException { |
| 687 | + if (reader != null) { |
| 688 | + reader.close(); |
| 689 | + } |
| 690 | + } |
| 691 | + |
| 692 | + @Override |
| 693 | + public BoundedSource<T> getCurrentSource() { |
| 694 | + return source; |
| 695 | + } |
583 | 696 | } |
584 | 697 | } |
585 | 698 |
|
|
0 commit comments