diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c54eb3a..46ffb62 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,11 +19,11 @@ jobs: with: fetch-depth: 0 - - name: Install JDK 11 + - name: Install JDK 21 uses: actions/setup-java@v4 with: distribution: zulu - java-version: '11' + java-version: '21' java-package: jdk - name: Install SBT diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 75096a0..61ffc05 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -11,11 +11,11 @@ jobs: - name: Checkout uses: actions/checkout@v4 - - name: Install JDK 11 + - name: Install JDK 21 uses: actions/setup-java@v4 with: distribution: zulu - java-version: '11' + java-version: '21' java-package: jdk - name: Install SBT diff --git a/.jvmopts b/.jvmopts new file mode 100644 index 0000000..48f70c2 --- /dev/null +++ b/.jvmopts @@ -0,0 +1,2 @@ +--enable-preview +--add-modules=jdk.incubator.vector diff --git a/.scalafmt.conf b/.scalafmt.conf index 6ea0481..cf7cb52 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -26,7 +26,7 @@ rewrite.rules = [ Imports ] rewrite.imports.expand = true rewrite.imports.sort = ascii rewrite.imports.groups = [ - ["(?!javax?\\.|scala\\.).+"], + ["(?!javax?\\.|jdk\\.|scala\\.).+"], ] rewrite.trailingCommas.style = always spaces.neverAroundInfixTypes = [ "##" ] diff --git a/LICENSE b/LICENSE index 3aa514a..f6d09fa 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2022-2022 Sam Guymer +Copyright (c) 2022-2025 Sam Guymer Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/README.md b/README.md index 74ea53e..8f13448 100644 --- a/README.md +++ b/README.md @@ -14,9 +14,8 @@ case class Test( int: Int, bool: Boolean, optInt: Option[Int], -) +) derives CsvRecordDecoder object Test { - implicit val decoder: 
CsvRecordDecoder[Test] = CsvRecordDecoder.derive val header = ::("str", List("int", "bool", "opt_int")) val csvHeader = CsvHeader.create(header)(decoder) } diff --git a/benchmark/README.md b/benchmark/README.md index ae18d41..a24d59b 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1,26 +1,30 @@ `benchmark/Jmh/run -i 10 -wi 5 -f 1 -t 2 ceesvee.benchmark.ParserBenchmark` +AMD Ryzen 9 9950X ``` # JMH version: 1.37 # VM version: JDK 25.0.1, OpenJDK 64-Bit Server VM, 25.0.1 -Benchmark Mode Cnt Score Error Units -ParserBenchmark.ceesvee avgt 10 261.357 ± 1.787 us/op -ParserBenchmark.scalaCsv avgt 10 741.778 ± 6.433 us/op -ParserBenchmark.univocity avgt 10 200.482 ± 2.715 us/op +Benchmark Mode Cnt Score Error Units +ParserBenchmark.ceesvee avgt 10 263.230 ± 0.679 us/op +ParserBenchmark.ceesveeVector avgt 10 134.205 ± 0.302 us/op +ParserBenchmark.scalaCsv avgt 10 748.232 ± 2.016 us/op +ParserBenchmark.univocity avgt 10 198.765 ± 0.982 us/op ``` ``` # JMH version: 1.37 # VM version: JDK 25, OpenJDK 64-Bit Server VM, 25+37-jvmci-b01 -Benchmark Mode Cnt Score Error Units -ParserBenchmark.ceesvee avgt 10 197.994 ± 2.344 us/op -ParserBenchmark.scalaCsv avgt 10 776.080 ± 1.457 us/op -ParserBenchmark.univocity avgt 10 208.226 ± 2.501 us/op +Benchmark Mode Cnt Score Error Units +ParserBenchmark.ceesvee avgt 10 187.441 ± 1.345 us/op +ParserBenchmark.ceesveeVector avgt 10 1484.755 ± 14.298 us/op +ParserBenchmark.scalaCsv avgt 10 780.945 ± 2.340 us/op +ParserBenchmark.univocity avgt 10 204.178 ± 1.702 us/op ``` `benchmark/Jmh/run -i 10 -wi 5 -f 1 -t 2 ceesvee.benchmark.DecoderBenchmark` +AMD Ryzen 9 9950X ``` # JMH version: 1.37 # VM version: JDK 25.0.1, OpenJDK 64-Bit Server VM, 25.0.1 diff --git a/benchmark/src/main/scala/ceesvee/benchmark/ParserBenchmark.scala b/benchmark/src/main/scala/ceesvee/benchmark/ParserBenchmark.scala index fb4018b..19fb021 100644 --- a/benchmark/src/main/scala/ceesvee/benchmark/ParserBenchmark.scala +++ 
b/benchmark/src/main/scala/ceesvee/benchmark/ParserBenchmark.scala @@ -11,16 +11,23 @@ import java.util.concurrent.TimeUnit @State(Scope.Thread) @BenchmarkMode(Array(Mode.AverageTime)) @OutputTimeUnit(TimeUnit.MICROSECONDS) +@Fork( + jvmArgs = Array( + "--enable-preview", + "--add-modules=jdk.incubator.vector", + ), +) class ParserBenchmark { private def line(i: Int) = List("basic string", " \"quoted \nstring\" ", i.toString, "456.789", "true").mkString(",") + private val charset = StandardCharsets.UTF_8 private val lines = (1 to 1000).map(line(_)).mkString("\n") - private def linesChunked = lines.grouped(8192) + private val linesBytes = lines.getBytes(charset) private def linesReader = { val streams = new java.util.ArrayList[ByteArrayInputStream]() - linesChunked.foreach { str => - streams.add(new ByteArrayInputStream(str.getBytes(StandardCharsets.UTF_8))) + linesBytes.grouped(8192).foreach { bytes => + streams.add(new ByteArrayInputStream(bytes)) } val is = new SequenceInputStream(java.util.Collections.enumeration(streams)) new InputStreamReader(is) @@ -32,7 +39,12 @@ class ParserBenchmark { @Benchmark def ceesvee: List[List[String]] = { - _root_.ceesvee.CsvParser.parse[List](linesChunked, ceesveeOptions).toList + _root_.ceesvee.CsvParser.parse[List](lines.grouped(8192), ceesveeOptions).toList + } + + @Benchmark + def ceesveeVector: List[List[String]] = { + _root_.ceesvee.CsvParserVector.parse[List](linesBytes.grouped(8192), charset, ceesveeOptions).toList } @Benchmark diff --git a/build.sbt b/build.sbt index 18a7008..59c9416 100644 --- a/build.sbt +++ b/build.sbt @@ -26,7 +26,7 @@ lazy val commonSettings = Seq( "-deprecation", "-encoding", "UTF-8", "-feature", - "-release", "11", + "-release", "21", "-unchecked", ), scalacOptions ++= (CrossVersion.partialVersion(scalaVersion.value) match { @@ -47,6 +47,7 @@ lazy val commonSettings = Seq( case Some((2, _)) => Seq( "-Vimplicits", "-Vtype-diffs", + "-Wconf:cat=scala3-migration:silent", "-Wdead-code", 
"-Wextra-implicit", "-Wnonunit-statement", @@ -59,15 +60,12 @@ lazy val commonSettings = Seq( "-Xlint:_,-byname-implicit", // exclude byname-implicit https://github.com/scala/bug/issues/12072 ) case _ => Seq( + "-Wconf:name=PatternMatchExhaustivity:error", "-Wnonunit-statement", "-Wunused:all", "-Wvalue-discard", ) }), - Test / scalacOptions ++= (CrossVersion.partialVersion(scalaVersion.value) match { - case Some((2, _)) => Seq("-Wconf:cat=scala3-migration:silent") - case _ => Seq.empty - }), Compile / console / scalacOptions ~= filterScalacConsoleOpts, Test / console / scalacOptions ~= filterScalacConsoleOpts, diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..ca0e117 --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,9 @@ +## Architecture + +### Module Structure + +- **core** - Main CSV parsing, encoding/decoding logic. Only dependency is an optional one on [cats](https://github.com/typelevel/cats) +- **fs2** - Integration with [fs2](https://github.com/typelevel/fs2) streams +- **zio** - Integration with [ZIO](https://github.com/zio/zio) streams +- **benchmark** - JMH performance benchmarks comparing against other CSV libraries +- **tests** - Integration tests with real-world CSV files diff --git a/modules/core/src/main/scala/ceesvee/CsvParser.scala b/modules/core/src/main/scala/ceesvee/CsvParser.scala index e590aee..1e3eb69 100644 --- a/modules/core/src/main/scala/ceesvee/CsvParser.scala +++ b/modules/core/src/main/scala/ceesvee/CsvParser.scala @@ -331,7 +331,7 @@ object CsvParser { fields.result() } - private def trimString(options: Options, str: String) = { + private[ceesvee] def trimString(options: Options, str: String) = { // always ignore whitespace around a quoted cell val trimmed = Options.Trim.True.strip(str) diff --git a/modules/core/src/main/scala/ceesvee/CsvParserVector.scala b/modules/core/src/main/scala/ceesvee/CsvParserVector.scala new file mode 100644 index 0000000..4f64925 --- /dev/null +++ 
b/modules/core/src/main/scala/ceesvee/CsvParserVector.scala @@ -0,0 +1,359 @@ +package ceesvee + +import java.nio.charset.Charset +import jdk.incubator.vector.ByteVector +import jdk.incubator.vector.VectorMask +import scala.annotation.tailrec +import scala.collection.Factory +import scala.collection.mutable + +@SuppressWarnings(Array( + "org.wartremover.warts.MutableDataStructures", + "org.wartremover.warts.Throw", + "org.wartremover.warts.Var", + "org.wartremover.warts.While", +)) +object CsvParserVector { + import CsvParser.Error + import CsvParser.Options + import CsvParser.ignoreTrimmedLine + + /** + * @see + * [[CsvParser.parse]] + */ + @throws[Error.LineTooLong]("if a line is longer than `maximumLineLength`") + def parse[C[_]]( + in: Iterator[Array[Byte]], + charset: Charset, + options: Options, + )(implicit f: Factory[String, C[String]]): Iterator[C[String]] = { + splitLines(in, options) + .map(parseLine(_, charset, options)) + .filter(fields => fields != null) + } + + /** + * Splits the given byte arrays into CSV lines using the Vector API by + * splitting on either '\r\n' or '\n'. + * + * '"' is the only valid escape for nested double quotes. 
+ */ + @throws[Error.LineTooLong]("if a line is longer than `maximumLineLength`") + private def splitLines(in: Iterator[Array[Byte]], options: Options): Iterator[Array[Byte]] = new SplitLinesVectorIterator(in, options) + private final class SplitLinesVectorIterator(in: Iterator[Array[Byte]], options: Options) extends Iterator[Array[Byte]] { + private val toOutput = mutable.Queue.empty[Array[Byte]] + private var state = State.initial + + override def hasNext = toOutput.nonEmpty || in.hasNext || state.leftover.nonEmpty + + @tailrec override def next() = { + if (toOutput.nonEmpty) { + toOutput.dequeue() + } else { + val leftover = state.leftover + if (leftover.length > options.maximumLineLength) { + throw Error.LineTooLong(options.maximumLineLength) + } + + if (!in.hasNext) { + state = State.initial + leftover + } else { + val bytes = in.next() + val (newState, lines) = splitBytes(bytes, state) + val _ = toOutput.enqueueAll(lines) + state = newState + next() + } + } + } + } + + private[ceesvee] class State( + val leftover: Array[Byte], + val insideQuote: Boolean, + val prevCarriageReturn: Boolean, + ) + private[ceesvee] object State { + val initial: State = new State( + leftover = Array.emptyByteArray, + insideQuote = false, + prevCarriageReturn = false, + ) + } + + private val Quote: Byte = '"' + private val Comma: Byte = ',' + private val NewLine: Byte = '\n' + private val CarriageReturn: Byte = '\r' + + private val ByteVectorSpecies = ByteVector.SPECIES_PREFERRED + + private[ceesvee] def splitBytes[C[S] <: Iterable[S]]( + bytes: Array[Byte], + state: State, + )(implicit f: Factory[Array[Byte], C[Array[Byte]]]): (State, C[Array[Byte]]) = { + + val builder = f.newBuilder + var insideQuote = state.insideQuote + var prevCarriageReturn = state.prevCarriageReturn + var sliceStart = 0 + + val loopBound = ByteVectorSpecies.loopBound(bytes.length) + + var i = 0 + while (i < bytes.length) { + val vector = if (i >= loopBound) { + val m = ByteVectorSpecies.indexInRange(i, 
bytes.length) + ByteVector.fromArray(ByteVectorSpecies, bytes, i, m) + } else { + ByteVector.fromArray(ByteVectorSpecies, bytes, i) + } + + val quotes = vector.eq(Quote) + var mask = quotes + var betweenQuotes = 0L + + // set all bits between quotes + var quoteStart = if (insideQuote) 0 else -1 + while (mask.anyTrue()) { + val r = mask.firstTrue() + mask = mask.xor(VectorMask.fromLong(vector.species(), 1L << r)) + + if (quoteStart >= 0) { + var j = r - 1 + while (j >= quoteStart) { + betweenQuotes = betweenQuotes | (1L << j) + j = j - 1 + } + quoteStart = -1 + } else { + quoteStart = r + } + } + if (quoteStart >= 0) { + var j = ByteVectorSpecies.length - 1 + while (j > quoteStart) { + betweenQuotes = betweenQuotes | (1L << j) + j = j - 1 + } + } + + val quoteMask = quotes.or(VectorMask.fromLong(vector.species(), betweenQuotes)) + insideQuote = (quotes.trueCount() + (if (insideQuote) 1 else 0)) % 2 == 1 + + val crChars = vector.eq(CarriageReturn) + val crIgnoringWithinQuotes = crChars.andNot(crChars.and(quoteMask)) + val nlChars = vector.eq(NewLine) + val nlIgnoringWithinQuotes = nlChars.andNot(nlChars.and(quoteMask)) + + /* \ = \r, | = \n + a\|b\c|"d\|e"|f + 000000010000100 = quoteChars + 000000011111100 = quoteMask + 010010000100000 = crChars + 010010000000000 = crIgnoringWithinQuotes + 001000100010010 = nlChars + 001000100000010 = nlIgnoringWithinQuotes + */ + + var nlIgnoringWithinQuotesBitSet = nlIgnoringWithinQuotes.toLong + while (java.lang.Long.bitCount(nlIgnoringWithinQuotesBitSet) > 0) { + val r = java.lang.Long.numberOfTrailingZeros(nlIgnoringWithinQuotesBitSet) + nlIgnoringWithinQuotesBitSet = nlIgnoringWithinQuotesBitSet ^ java.lang.Long.lowestOneBit(nlIgnoringWithinQuotesBitSet) + + val isPrevCr = + (r == 0 && prevCarriageReturn) || + (r > 0 && crIgnoringWithinQuotes.laneIsSet(r - 1)) + + val sliceTo = i + r + val leftover = if (sliceStart == 0) state.leftover else Array.emptyByteArray + val to = if (isPrevCr) sliceTo - 1 else sliceTo + val _ = builder 
+= leftover ++ arraySlice(bytes, sliceStart, to, i, 0) + + sliceStart = sliceTo + 1 + } + + prevCarriageReturn = crIgnoringWithinQuotes.laneIsSet(vector.length() - 1) + i = i + ByteVectorSpecies.length + } + + val leftover = if (sliceStart == 0) { + state.leftover ++ bytes + } else { + bytes.slice(sliceStart, bytes.length) + } + + (new State(leftover, insideQuote = insideQuote, prevCarriageReturn = prevCarriageReturn), builder.result()) + } + + /** + * Parse a line into a collection of CSV fields. + */ + @SuppressWarnings(Array("org.wartremover.warts.AsInstanceOf", "org.wartremover.warts.Null")) + private[ceesvee] def parseLine[C[_]]( + bytes: Array[Byte], + charset: Charset, + options: Options, + )(implicit f: Factory[String, C[String]]): C[String] = { + + val builder = f.newBuilder + var builderEmpty = true + var insideQuote = false + var sliceStart = 0 + + val loopBound = ByteVectorSpecies.loopBound(bytes.length) + + var i = 0 + while (i < bytes.length) { + val vector = if (i >= loopBound) { + val m = ByteVectorSpecies.indexInRange(i, bytes.length) + ByteVector.fromArray(ByteVectorSpecies, bytes, i, m) + } else { + ByteVector.fromArray(ByteVectorSpecies, bytes, i) + } + + val quotes = vector.eq(Quote) + var mask = quotes + var betweenQuotes = 0L + + // set all bits between quotes + var quoteStart = if (insideQuote) 0 else -1 + while (mask.anyTrue()) { + val r = mask.firstTrue() + mask = mask.xor(VectorMask.fromLong(vector.species(), 1L << r)) + + if (quoteStart >= 0) { + var j = r - 1 + while (j >= quoteStart) { + betweenQuotes = betweenQuotes | (1L << j) + j = j - 1 + } + quoteStart = -1 + } else { + quoteStart = r + } + } + if (quoteStart >= 0) { + var j = ByteVectorSpecies.length - 1 + while (j > quoteStart) { + betweenQuotes = betweenQuotes | (1L << j) + j = j - 1 + } + } + + val quoteMask = quotes.or(VectorMask.fromLong(vector.species(), betweenQuotes)) + insideQuote = (quotes.trueCount() + (if (insideQuote) 1 else 0)) % 2 == 1 + + val commaChars = 
vector.eq(Comma) + val commaIgnoringWithinQuotes = commaChars.andNot(commaChars.and(quoteMask)) + + /* + a,"b""c","d,e","",f + 0010110101000101100 = quoteChars + 0100000010010010010 = commaChars + 0100000010000010010 = commaIgnoringWithinQuotes + */ + + var commaIgnoringWithinQuotesBitSet = commaIgnoringWithinQuotes.toLong + while (java.lang.Long.bitCount(commaIgnoringWithinQuotesBitSet) > 0) { + val r = java.lang.Long.numberOfTrailingZeros(commaIgnoringWithinQuotesBitSet) + commaIgnoringWithinQuotesBitSet = commaIgnoringWithinQuotesBitSet ^ java.lang.Long.lowestOneBit(commaIgnoringWithinQuotesBitSet) + + val sliceTo = i + r + val str = handleField(arraySlice(bytes, sliceStart, sliceTo, i, 0), charset, options) + if (builderEmpty && ignoreTrimmedLine(str, options)) { + i = bytes.length + commaIgnoringWithinQuotesBitSet = 0 + } else { + val _ = builder += str + builderEmpty = false + sliceStart = sliceTo + 1 + } + } + + i = i + ByteVectorSpecies.length + } + + val remaining = if (sliceStart == 0) { + bytes + } else { + bytes.slice(sliceStart, bytes.length) + } + + val str = handleField(remaining, charset, options) + if (builderEmpty && ignoreTrimmedLine(str, options)) { + () + } else { + val _ = builder += str + builderEmpty = false + } + + if (builderEmpty) null.asInstanceOf[C[String]] else builder.result() + } + + private def handleField(bytes: Array[Byte], charset: Charset, options: Options) = { + val str = new String(bytes, charset) + val s = CsvParser.trimString(options, str) + s.replace("\"\"", "\"") + } + + private def arraySlice(src: Array[Byte], from: Int, to: Int, offset: Int, ignore: Long) = { + var from_ = from + var to_ = to + var ignoreCount = 0 + + var ignoreBitsSet = ignore + while (java.lang.Long.bitCount(ignoreBitsSet) > 0) { + val i = java.lang.Long.numberOfTrailingZeros(ignoreBitsSet) + offset + ignoreBitsSet = ignoreBitsSet ^ java.lang.Long.lowestOneBit(ignoreBitsSet) + + if (i < from_ || i > to_) { + () + } else if (i == from_) { + from_ = 
from_ + 1 + } else if (i == to_) { + to_ = to_ - 1 + } else { + ignoreCount = ignoreCount + 1 + } + } + + val size = to_ - from_ + ignoreCount + if (size <= 0 || from_ >= to_) { + Array.emptyByteArray + } else { + val dest = Array.ofDim[Byte](size) + + if (ignoreCount == 0) { + System.arraycopy(src, from_, dest, 0, size) + } else { + var srcPosition = from_ + var destPosition = 0 + + var ignoreBitsSet2 = ignore + while (java.lang.Long.bitCount(ignoreBitsSet2) > 0) { + val i = java.lang.Long.numberOfTrailingZeros(ignoreBitsSet2) + offset + ignoreBitsSet2 = ignoreBitsSet2 ^ java.lang.Long.lowestOneBit(ignoreBitsSet2) + + if (i < from_ || i > to_) { + () + } else if (srcPosition == i) { + srcPosition = i + 1 + } else { + val positionSize = srcPosition - i + System.arraycopy(src, srcPosition, dest, destPosition, positionSize) + srcPosition = i + 1 + destPosition = destPosition + positionSize + } + } + if (srcPosition < to_) { + System.arraycopy(src, srcPosition, dest, destPosition, to_ - srcPosition) + } + } + + dest + } + } +} diff --git a/modules/core/src/main/scala/ceesvee/package.scala b/modules/core/src/main/scala/ceesvee/package.scala new file mode 100644 index 0000000..31d34d4 --- /dev/null +++ b/modules/core/src/main/scala/ceesvee/package.scala @@ -0,0 +1,4 @@ +package object ceesvee { + + def VectorAPIAvailable: Boolean = ModuleLayer.boot().findModule("jdk.incubator.vector").isPresent +} diff --git a/modules/core/src/test/scala/ceesvee/CsvParserSpec.scala b/modules/core/src/test/scala/ceesvee/CsvParserSpec.scala index b0491fa..db419a9 100644 --- a/modules/core/src/test/scala/ceesvee/CsvParserSpec.scala +++ b/modules/core/src/test/scala/ceesvee/CsvParserSpec.scala @@ -5,10 +5,120 @@ import zio.ZIO import zio.test.ZIOSpecDefault import zio.test.assertTrue -object CsvParserSpec extends ZIOSpecDefault with CsvParserParserSuite { +object CsvParserSpec extends ZIOSpecDefault + with CsvParserParserSuite + with CsvSplitStringsSuite[CsvParser.State] + with 
CsvParserLineSuite { override val spec = suite("CsvParser")( parserSuite, + splitStringsSuite, + parseLineSuite, + ) + + override protected def parse(lines: Iterable[String], options: CsvParser.Options) = { + val input = lines.mkString("\n").grouped(8192) + val result = CsvParser.parse[List](input, options) + ZIO.succeed(Chunk.fromIterator(result)) + } + + override protected def parseLine(line: String, options: CsvParser.Options) = { + CsvParser.parseLine[List](line, options) + } + + override protected def splitStrings(strings: List[String], state: CsvParser.State) = CsvParser.splitStrings(strings, state) + + override protected def initialState = CsvParser.State.initial + override protected def stateLeftover(s: CsvParser.State) = s.leftover +} + +trait CsvParserParserSuite { self: ZIOSpecDefault => + + protected def parse( + lines: Iterable[String], + options: CsvParser.Options, + ): ZIO[Any, Throwable, Chunk[List[String]]] + + protected def parserSuite = suite("parser")( + test("lots") { + def line(i: Int) = List("basic string", " \"quoted \nstring\" ", i.toString, "456.789", "true").mkString(",") + + val lines = (1 to 10).map(line(_)) + parse(lines, CsvParser.Options.Defaults).map { result => + assertTrue(result.length == 10) + } + }, + suite("comment prefix")({ + val lines = List( + "a,b,c", + "#a,b,c", + "#", + " #", + "d,e,f", + ) + + test("no comments") { + val opts = CsvParser.Options.Defaults.copy(commentPrefix = None, trim = CsvParser.Options.Trim.False) + parse(lines, opts).map { result => + assertTrue(result == Chunk( + List("a", "b", "c"), + List("#a", "b", "c"), + List("#"), + List(" #"), + List("d", "e", "f"), + )) + } + } :: + test("false") { + val opts = CsvParser.Options.Defaults.copy(commentPrefix = Some("#")) + parse(lines, opts).map { result => + assertTrue(result == Chunk( + List("a", "b", "c"), + List("d", "e", "f"), + )) + } + } :: Nil + }), + suite("skip blank rows")({ + val lines = List( + "a,b,c", + "", + " ", + "d,e,f", + ) + + 
test("true") { + val opts = CsvParser.Options.Defaults.copy(skipBlankRows = true) + parse(lines, opts).map { result => + assertTrue(result == Chunk( + List("a", "b", "c"), + List("d", "e", "f"), + )) + } + } :: + test("false") { + val opts = CsvParser.Options.Defaults.copy(skipBlankRows = false) + parse(lines, opts).map { result => + assertTrue(result == Chunk( + List("a", "b", "c"), + List(""), + List(""), + List("d", "e", "f"), + )) + } + } :: Nil + }), + ) +} + +trait CsvSplitStringsSuite[S] { self: ZIOSpecDefault => + + protected def splitStrings(strings: List[String], state: S): (S, List[String]) + + protected def initialState: S + protected def stateLeftover(s: S): String + + protected def splitStringsSuite = { suite("split strings")( test("trailing new lines") { val strings = List( @@ -18,9 +128,9 @@ object CsvParserSpec extends ZIOSpecDefault with CsvParserParserSuite { "jkl", "\nmno", ) - val (state, lines) = CsvParser.splitStrings(strings, CsvParser.State.initial) + val (state, lines) = splitStrings(strings, initialState) assertTrue(lines == List("abc\rdef", "ghi", "jkl")) && - assertTrue(state.leftover == "mno") + assertTrue(stateLeftover(state) == "mno") }, test("trailing new lines aligned to vector boundary") { val strings = List( @@ -30,13 +140,13 @@ object CsvParserSpec extends ZIOSpecDefault with CsvParserParserSuite { "012345678901234567890123456789012345678901234567890123456789abcd", "\nmno", ) - val (state, lines) = CsvParser.splitStrings(strings, CsvParser.State.initial) + val (state, lines) = splitStrings(strings, initialState) assertTrue(lines == List( "012345678901234567890123456789012345678901234567890123456789abc\r012345678901234567890123456789012345678901234567890123456789abc", "012345678901234567890123456789012345678901234567890123456789ab", "012345678901234567890123456789012345678901234567890123456789abcd", )) && - assertTrue(state.leftover == "mno") + assertTrue(stateLeftover(state) == "mno") }, test("trailing double quotes") { val 
strings = List( @@ -45,33 +155,46 @@ object CsvParserSpec extends ZIOSpecDefault with CsvParserParserSuite { "\"", "\nfg\"", ) - val (state, lines) = CsvParser.splitStrings(strings, CsvParser.State.initial) + val (state, lines) = splitStrings(strings, initialState) val strings2 = List( "\n\"\"\"", "\n\"hi\"\"", ) - val (state2, lines2) = CsvParser.splitStrings(strings2, state) + val (state2, lines2) = splitStrings(strings2, state) val strings3 = List( "j\"", "\nmno", ) - val (state3, lines3) = CsvParser.splitStrings(strings3, state2) + val (state3, lines3) = splitStrings(strings3, state2) assertTrue( lines == List("""a,"b",c,"d""e","""""), - state.insideQuoteIndex == 2, - state.leftover == "fg\"", + stateLeftover(state) == "fg\"", ) && assertTrue( lines2 == List("fg\"\n\"\"\""), - state2.insideQuoteIndex == 0, - state2.leftover == "\"hi\"\"", + stateLeftover(state2) == "\"hi\"\"", ) && assertTrue( lines3 == List("\"hi\"\"j\""), - state3.insideQuoteIndex == -9, - state3.leftover == "mno", + stateLeftover(state3) == "mno", ) }, + test("trailing double quotes aligned to vector boundary") { + val strings = List( + "\"012345678901234567890123456789012345678901234567890123456789ab\"", + "\n\"012345678901234567890123456789012345678901234567890123456789\n\"", + "\"\n012345678901234567890123456789012345678901234567890123456789a\"", + "\n\"012345678901234567890123456789012345678901234567890123456789\"\"", + "\n0123456789\"\nmno", + ) + val (state, lines) = splitStrings(strings, initialState) + assertTrue(lines == List( + "\"012345678901234567890123456789012345678901234567890123456789ab\"", + "\"012345678901234567890123456789012345678901234567890123456789\n\"\"\n012345678901234567890123456789012345678901234567890123456789a\"", + "\"012345678901234567890123456789012345678901234567890123456789\"\"\n0123456789\"", + )) && + assertTrue(stateLeftover(state) == "mno") + }, test("quotes and new lines") { val strings = List( "a\"b\"c\n", @@ -80,28 +203,31 @@ object CsvParserSpec 
extends ZIOSpecDefault with CsvParserParserSuite { "\"jkl\"", "\nnmno", ) - val (state, lines) = CsvParser.splitStrings(strings, CsvParser.State.initial) + val (state, lines) = splitStrings(strings, initialState) assertTrue(lines == List( "a\"b\"c", "d\"\ne\r\nf\"", "g\"hi\r\"\"jkl\"", )) && - assertTrue(state.leftover == "nmno") + assertTrue(stateLeftover(state) == "nmno") }, // TODO property based tests - ), - parseLineSuite, - ) + ) + } +} - private def parseLineSuite = { - import CsvParser.parseLine +trait CsvParserLineSuite { self: ZIOSpecDefault => + + protected def parseLine(line: String, options: CsvParser.Options): List[String] + + protected def parseLineSuite = { import CsvParser.Options suite("parse line")( suite("escape character")( test("double quote") { val line = """a,"b""c",d,e"f""" - assertTrue(parseLine[List](line, Options.Defaults) == List("a", """b"c""", "d", "e\"f")) + assertTrue(parseLine(line, Options.Defaults) == List("a", """b"c""", "d", "e\"f")) }, ), suite("trim")({ @@ -109,116 +235,31 @@ object CsvParserSpec extends ZIOSpecDefault with CsvParserParserSuite { test("true") { val opts = Options.Defaults.copy(trim = Options.Trim.True) - assertTrue(parseLine[List](line, opts) == List("abc", "def", "ghi", "jkl", " mno ", "")) + assertTrue(parseLine(line, opts) == List("abc", "def", "ghi", "jkl", " mno ", "")) } :: test("false") { val opts = Options.Defaults.copy(trim = Options.Trim.False) - assertTrue(parseLine[List](line, opts) == List("abc", " def", "ghi ", " jkl ", " mno ", " ")) + assertTrue(parseLine(line, opts) == List("abc", " def", "ghi ", " jkl ", " mno ", " ")) } :: test("start") { val opts = Options.Defaults.copy(trim = Options.Trim.Start) - assertTrue(parseLine[List](line, opts) == List("abc", "def", "ghi ", "jkl ", " mno ", "")) + assertTrue(parseLine(line, opts) == List("abc", "def", "ghi ", "jkl ", " mno ", "")) } :: test("end") { val opts = Options.Defaults.copy(trim = Options.Trim.End) - assertTrue(parseLine[List](line, opts) 
== List("abc", " def", "ghi", " jkl", " mno ", "")) + assertTrue(parseLine(line, opts) == List("abc", " def", "ghi", " jkl", " mno ", "")) } :: Nil }), test("complex") { val line = "abc, def ,,\" g,\"\"h\"\",\ti\" , " - val result = parseLine[List](line, Options.Defaults) + val result = parseLine(line, Options.Defaults) assertTrue(result == List("abc", "def", "", " g,\"h\",\ti", "")) }, test("json") { val line = """abc,"{""data"": {""message"": ""blah \""quoted\""\n pos 123""}, ""type"": ""unhandled""}",xyz""" - val result = parseLine[List](line, Options.Defaults) + val result = parseLine(line, Options.Defaults) assertTrue(result == List("abc", """{"data": {"message": "blah \"quoted\"\n pos 123"}, "type": "unhandled"}""", "xyz")) }, ) } - - override protected def parse(lines: Iterable[String], options: CsvParser.Options) = { - val input = lines.mkString("\n").grouped(8192) - val result = CsvParser.parse[List](input, options) - ZIO.succeed(Chunk.fromIterator(result)) - } -} - -trait CsvParserParserSuite { self: ZIOSpecDefault => - - protected def parse( - lines: Iterable[String], - options: CsvParser.Options, - ): ZIO[Any, Throwable, Chunk[List[String]]] - - protected def parserSuite = suite("parser")( - test("lots") { - def line(i: Int) = List("basic string", " \"quoted \nstring\" ", i.toString, "456.789", "true").mkString(",") - - val lines = (1 to 10).map(line(_)) - parse(lines, CsvParser.Options.Defaults).map { result => - assertTrue(result.length == 10) - } - }, - suite("comment prefix")({ - val lines = List( - "a,b,c", - "#a,b,c", - "#", - " #", - "d,e,f", - ) - - test("no comments") { - val opts = CsvParser.Options.Defaults.copy(commentPrefix = None, trim = CsvParser.Options.Trim.False) - parse(lines, opts).map { result => - assertTrue(result == Chunk( - List("a", "b", "c"), - List("#a", "b", "c"), - List("#"), - List(" #"), - List("d", "e", "f"), - )) - } - } :: - test("false") { - val opts = CsvParser.Options.Defaults.copy(commentPrefix = Some("#")) - 
parse(lines, opts).map { result => - assertTrue(result == Chunk( - List("a", "b", "c"), - List("d", "e", "f"), - )) - } - } :: Nil - }), - suite("skip blank rows")({ - val lines = List( - "a,b,c", - "", - " ", - "d,e,f", - ) - - test("true") { - val opts = CsvParser.Options.Defaults.copy(skipBlankRows = true) - parse(lines, opts).map { result => - assertTrue(result == Chunk( - List("a", "b", "c"), - List("d", "e", "f"), - )) - } - } :: - test("false") { - val opts = CsvParser.Options.Defaults.copy(skipBlankRows = false) - parse(lines, opts).map { result => - assertTrue(result == Chunk( - List("a", "b", "c"), - List(""), - List(""), - List("d", "e", "f"), - )) - } - } :: Nil - }), - ) } diff --git a/modules/core/src/test/scala/ceesvee/CsvParserVectorSpec.scala b/modules/core/src/test/scala/ceesvee/CsvParserVectorSpec.scala new file mode 100644 index 0000000..7e7937f --- /dev/null +++ b/modules/core/src/test/scala/ceesvee/CsvParserVectorSpec.scala @@ -0,0 +1,40 @@ +package ceesvee + +import zio.Chunk +import zio.ZIO +import zio.test.ZIOSpecDefault + +import java.nio.charset.StandardCharsets + +object CsvParserVectorSpec extends ZIOSpecDefault + with CsvParserParserSuite + with CsvSplitStringsSuite[CsvParserVector.State] + with CsvParserLineSuite { + + private val charset = StandardCharsets.UTF_8 + + override val spec = suite("CsvParserVector")( + parserSuite, + splitStringsSuite, + parseLineSuite, + ) + + override protected def parse(lines: Iterable[String], options: CsvParser.Options) = { + val input = lines.mkString("\n").grouped(8192).map(_.getBytes(charset)) + val result = CsvParserVector.parse[List](input, charset, options) + ZIO.succeed(Chunk.fromIterator(result)) + } + + override protected def parseLine(line: String, options: CsvParser.Options) = { + CsvParserVector.parseLine[List](line.getBytes(charset), charset, options) + } + + override protected def splitStrings(strings: List[String], state: CsvParserVector.State) = { + val input = 
strings.mkString("").getBytes(charset) + val (s, o) = CsvParserVector.splitBytes[List](input, state) + (s, o.map(new String(_, charset))) + } + + override protected def initialState = CsvParserVector.State.initial + override protected def stateLeftover(s: CsvParserVector.State) = new String(s.leftover, charset) +} diff --git a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala index 888aad8..2731939 100644 --- a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala +++ b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala @@ -5,7 +5,7 @@ import _root_.zio.Chunk import _root_.zio.NonEmptyChunk import _root_.zio.Ref import _root_.zio.Scope -import _root_.zio.Trace +import _root_.zio.Trace as ZIOTrace import _root_.zio.ZIO import _root_.zio.stream.ZPipeline import _root_.zio.stream.ZSink @@ -14,7 +14,6 @@ import ceesvee.CsvParser import ceesvee.CsvReader object ZioCsvParser { - import CsvParser.Error import CsvParser.State import CsvParser.ignoreLine import CsvParser.parseLine @@ -28,8 +27,8 @@ object ZioCsvParser { stream: ZStream[R, E, String], options: CsvReader.Options, )(implicit - trace: Trace, - ): ZIO[Scope & R, Either[E, Error], (Chunk[String], ZStream[Any, Either[E, Error], Chunk[String]])] = { + trace: ZIOTrace, + ): ZIO[Scope & R, Error[E], (Chunk[String], ZStream[Any, Error[E], Chunk[String]])] = { stream.mapError(Left(_)).peel { extractFirstLine(options).mapError(Right(_)) }.map { case ((headers, state, records), s) => @@ -37,30 +36,42 @@ object ZioCsvParser { } } - private def extractFirstLine(options: CsvReader.Options)(implicit trace: Trace) = { + private def extractFirstLine(options: CsvReader.Options)(implicit trace: ZIOTrace) = { + def process(state: State, strings: Chunk[String]) = { + val (newState, lines) = splitStrings(strings, state) + val records = lines.filter(str => !ignoreLine(str, options)).map(parseLine[Chunk](_, options)) + (newState, records) + } + + 
extractFirstLine_(State.initial, options)(_.leftover.length, process) + } + + private[zio] def extractFirstLine_[A, S](initialState: S, options: CsvReader.Options)( + leftoverLength: S => Int, + process: (S, Chunk[A]) => (S, Iterable[Chunk[String]]), + )(implicit trace: ZIOTrace) = { val initial: Chunk[Chunk[String]] = Chunk.empty @SuppressWarnings(Array("org.wartremover.warts.IterableOps")) - def done(state: State, records: Chunk[Chunk[String]]) = { + def done(state: S, records: Chunk[Chunk[String]]) = { NonEmptyChunk.fromChunk(records).map { rs => Push.emit((rs.head, state, rs.tail), Chunk.empty) } } - val push = Ref.make((State.initial, initial)).map { stateRef => (chunk: Option[Chunk[String]]) => + val push = Ref.make((initialState, initial)).map { stateRef => (chunk: Option[Chunk[A]]) => chunk match { case None => stateRef.get.flatMap { case (state, lines) => done(state, lines).getOrElse(Push.emit((Chunk.empty, state, lines), Chunk.empty)) } - case Some(strings) => + case Some(chunks) => stateRef.get.flatMap { case (state, records) => - if (state.leftover.length > options.maximumLineLength) { - Push.fail(Error.LineTooLong(options.maximumLineLength), Chunk.empty) + if (leftoverLength(state) > options.maximumLineLength) { + Push.fail(CsvParser.Error.LineTooLong(options.maximumLineLength), Chunk.empty) } else { - val (newState, lines) = splitStrings(strings, state) - val moreRecords = lines.filter(str => !ignoreLine(str, options)).map(parseLine[Chunk](_, options)) + val (newState, moreRecords) = process(state, chunks) val _records = records ++ moreRecords done(newState, _records).getOrElse(stateRef.set((newState, _records)) *> Push.more) } @@ -84,11 +95,11 @@ object ZioCsvParser { */ def parse( options: CsvParser.Options, - )(implicit trace: Trace): ZPipeline[Any, Error, String, Chunk[String]] = { + )(implicit trace: ZIOTrace): ZPipeline[Any, CsvParser.Error, String, Chunk[String]] = { _parse(State.initial, options) } - private[ceesvee] def _parse(state: State, 
options: CsvParser.Options)(implicit trace: Trace) = { + private[ceesvee] def _parse(state: State, options: CsvParser.Options)(implicit trace: ZIOTrace) = { _splitLines(state, options) >>> ZPipeline.filter[String](str => !ignoreLine(str, options)) >>> ZPipeline.map(parseLine[Chunk](_, options)) @@ -101,14 +112,14 @@ object ZioCsvParser { */ def splitLines( options: CsvParser.Options, - )(implicit trace: Trace): ZPipeline[Any, Error, String, String] = { + )(implicit trace: ZIOTrace): ZPipeline[Any, CsvParser.Error, String, String] = { _splitLines(State.initial, options) } private def _splitLines( state: State, options: CsvParser.Options, - )(implicit trace: Trace) = ZPipeline.fromPush { + )(implicit trace: ZIOTrace) = ZPipeline.fromPush { Ref.make(state).map { stateRef => (chunk: Option[Chunk[String]]) => chunk match { case None => @@ -118,7 +129,7 @@ object ZioCsvParser { case Some(strings) => stateRef.get.flatMap { case State(leftover, _, _) => - ZIO.fail(Error.LineTooLong(options.maximumLineLength)) + ZIO.fail(CsvParser.Error.LineTooLong(options.maximumLineLength)) .when(leftover.length > options.maximumLineLength) } *> stateRef.modify(splitStrings(strings, _).swap) } diff --git a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParserVector.scala b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParserVector.scala new file mode 100644 index 0000000..f1c993c --- /dev/null +++ b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParserVector.scala @@ -0,0 +1,96 @@ +package ceesvee.zio + +import _root_.zio.Chunk +import _root_.zio.Ref +import _root_.zio.Scope +import _root_.zio.Trace as ZIOTrace +import _root_.zio.ZIO +import _root_.zio.stream.ZPipeline +import _root_.zio.stream.ZStream +import ceesvee.CsvParser +import ceesvee.CsvParserVector +import ceesvee.CsvReader + +import java.nio.charset.Charset + +object ZioCsvParserVector { + import CsvParserVector.State + import CsvParserVector.parseLine + import CsvParserVector.splitBytes + + /** + * Turns a stream of bytes into a 
stream of CSV records extracting the first + * record. + */ + def parseWithHeader[R, E]( + stream: ZStream[R, E, Byte], + charset: Charset, + options: CsvReader.Options, + )(implicit + trace: ZIOTrace, + ): ZIO[Scope & R, Error[E], (Chunk[String], ZStream[Any, Error[E], Chunk[String]])] = { + stream.mapError(Left(_)).peel { + extractFirstLine(charset, options).mapError(Right(_)) + }.map { case ((headers, state, records), s) => + (headers, ZStream.fromChunk(records) ++ (s >>> _parse(state, charset, options).mapError(Right(_)))) + } + } + + private def extractFirstLine(charset: Charset, options: CsvReader.Options)(implicit trace: ZIOTrace) = { + def process(state: State, bytes: Chunk[Byte]) = { + val (newState, lines) = splitBytes(bytes.toArray, state) + val records = lines.map(parseLine[Chunk](_, charset, options)).filter(_ != null) + (newState, records) + } + + ZioCsvParser.extractFirstLine_(State.initial, options)(_.leftover.length, process) + } + + /** + * Turns a stream of bytes into a stream of CSV records. + */ + def parse( + charset: Charset, + options: CsvParser.Options, + )(implicit trace: ZIOTrace): ZPipeline[Any, CsvParser.Error, Byte, Chunk[String]] = { + _parse(State.initial, charset, options) + } + + private[ceesvee] def _parse(state: State, charset: Charset, options: CsvParser.Options)(implicit trace: ZIOTrace) = { + _splitLines(state, options) >>> + ZPipeline.map(parseLine[Chunk](_, charset, options)) >>> + ZPipeline.filter[Chunk[String]](_ != null) + } + + /** + * Split bytes into CSV lines using both '\n' and '\r\n' as delimiters. + * + * Delimiters within double-quotes are ignored. 
+ */ + def splitLines( + charset: Charset, + options: CsvParser.Options, + )(implicit trace: ZIOTrace): ZPipeline[Any, CsvParser.Error, Byte, String] = { + _splitLines(State.initial, options).map(new String(_, charset)) + } + + private def _splitLines( + state: State, + options: CsvParser.Options, + )(implicit trace: ZIOTrace) = ZPipeline.fromPush { + Ref.make(state).map { stateRef => (chunk: Option[Chunk[Byte]]) => + chunk match { + case None => + stateRef.getAndSet(State.initial).map { s => + if (s.leftover.isEmpty) Chunk.empty else Chunk(s.leftover) + } + + case Some(bytes) => + stateRef.get.flatMap { s => + ZIO.fail(CsvParser.Error.LineTooLong(options.maximumLineLength)) + .when(s.leftover.length > options.maximumLineLength) + } *> stateRef.modify(splitBytes[Chunk](bytes.toArray, _).swap) + } + } + } +} diff --git a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReader.scala b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReader.scala index 5beaeca..2bc10ad 100644 --- a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReader.scala +++ b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReader.scala @@ -1,8 +1,9 @@ package ceesvee.zio -import _root_.zio.Cause +import _root_.zio.Chunk +import _root_.zio.Exit import _root_.zio.Scope -import _root_.zio.Trace +import _root_.zio.Trace as ZIOTrace import _root_.zio.ZIO import _root_.zio.stream.ZPipeline import _root_.zio.stream.ZStream @@ -12,7 +13,6 @@ import ceesvee.CsvReader import ceesvee.CsvRecordDecoder object ZioCsvReader { - import CsvParser.Error /** * Turns a stream of strings into a stream of decoded CSV records. 
@@ -24,18 +24,25 @@ object ZioCsvReader { header: CsvHeader[T], options: CsvReader.Options, )(implicit - trace: Trace, - ): ZIO[Scope & R, Either[Either[E, Error], CsvHeader.MissingHeaders], ZStream[R, Either[E, Error], Either[CsvHeader.Errors, T]]] = { - for { - tuple <- ZioCsvParser.parseWithHeader(stream, options).mapError(Left(_)) - (headerFields, s) = tuple - decoder <- header.create(headerFields) match { - case Left(error) => ZIO.refailCause(Cause.fail(error)).mapError(Right(_)) - case Right(decoder) => ZIO.succeed(decoder) - } - } yield { - s.map(decoder.decode(_)) + trace: ZIOTrace, + ): ZIO[Scope & R, Either[Error[E], CsvHeader.MissingHeaders], ZStream[R, Error[E], Either[CsvHeader.Errors, T]]] = { + decodeWithHeader_[R, E, T](ZioCsvParser.parseWithHeader(stream, options), header) + } + + private[zio] def decodeWithHeader_[R, E, T]( + parseWithHeader: ZIO[Scope & R, Error[E], (Chunk[String], ZStream[Any, Error[E], Chunk[String]])], + header: CsvHeader[T], + )(implicit + trace: ZIOTrace, + ): ZIO[Scope & R, Either[Error[E], CsvHeader.MissingHeaders], ZStream[R, Error[E], Either[CsvHeader.Errors, T]]] = for { + tuple <- parseWithHeader.mapError(Left(_)) + (headerFields, s) = tuple + decoder <- header.create(headerFields) match { + case Left(error) => Exit.fail(error).mapError(Right(_)) + case Right(decoder) => Exit.succeed(decoder) } + } yield { + s.map(decoder.decode(_)) } /** @@ -45,8 +52,8 @@ object ZioCsvReader { options: CsvReader.Options, )(implicit D: CsvRecordDecoder[T], - trace: Trace, - ): ZPipeline[Any, Error, String, Either[CsvRecordDecoder.Errors, T]] = { + trace: ZIOTrace, + ): ZPipeline[Any, CsvParser.Error, String, Either[CsvRecordDecoder.Errors, T]] = { ZioCsvParser.parse(options) >>> ZPipeline.map(D.decode(_)) } } diff --git a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReaderVector.scala b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReaderVector.scala new file mode 100644 index 0000000..ef35b1b --- /dev/null +++ 
b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReaderVector.scala @@ -0,0 +1,46 @@ +package ceesvee.zio + +import _root_.zio.Scope +import _root_.zio.Trace as ZIOTrace +import _root_.zio.ZIO +import _root_.zio.stream.ZPipeline +import _root_.zio.stream.ZStream +import ceesvee.CsvHeader +import ceesvee.CsvParser +import ceesvee.CsvReader +import ceesvee.CsvRecordDecoder + +import java.nio.charset.Charset + +object ZioCsvReaderVector { + + /** + * Turns a stream of bytes into a stream of decoded CSV records. + * + * CSV lines are reordered based on the given headers. + */ + def decodeWithHeader[R, E, T]( + stream: ZStream[R, E, Byte], + header: CsvHeader[T], + charset: Charset, + options: CsvReader.Options, + )(implicit + trace: ZIOTrace, + ): ZIO[Scope & R, Either[Error[E], CsvHeader.MissingHeaders], ZStream[R, Error[E], Either[CsvHeader.Errors, T]]] = { + val parser = ZioCsvParserVector.parseWithHeader(stream, charset, options) + ZioCsvReader.decodeWithHeader_[R, E, T](parser, header) + } + + /** + * Turns a stream of bytes into a stream of decoded CSV records. 
+ */ + def decode[T]( + charset: Charset, + options: CsvReader.Options, + )(implicit + D: CsvRecordDecoder[T], + trace: ZIOTrace, + ): ZPipeline[Any, CsvParser.Error, Byte, Either[CsvRecordDecoder.Errors, T]] = { + ZioCsvParserVector.parse(charset, options) >>> ZPipeline.map(D.decode(_)) + } +} diff --git a/modules/zio/src/main/scala/ceesvee/zio/package.scala b/modules/zio/src/main/scala/ceesvee/zio/package.scala new file mode 100644 index 0000000..3322725 --- /dev/null +++ b/modules/zio/src/main/scala/ceesvee/zio/package.scala @@ -0,0 +1,6 @@ +package ceesvee + +package object zio { + + type Error[E] = Either[E, CsvParser.Error] +} diff --git a/modules/zio/src/test/scala/ceesvee/zio/CsvParserSpec.scala b/modules/zio/src/test/scala/ceesvee/zio/ZioCsvParserSpec.scala similarity index 77% rename from modules/zio/src/test/scala/ceesvee/zio/CsvParserSpec.scala rename to modules/zio/src/test/scala/ceesvee/zio/ZioCsvParserSpec.scala index 63794ce..0dddc4a 100644 --- a/modules/zio/src/test/scala/ceesvee/zio/CsvParserSpec.scala +++ b/modules/zio/src/test/scala/ceesvee/zio/ZioCsvParserSpec.scala @@ -4,9 +4,9 @@ import ceesvee.CsvParser import zio.stream.ZStream import zio.test.ZIOSpecDefault -object CsvParserSpec extends ZIOSpecDefault with ceesvee.CsvParserParserSuite { +object ZioCsvParserSpec extends ZIOSpecDefault with ceesvee.CsvParserParserSuite { - override val spec = suite("CsvParser")( + override val spec = suite("ZioCsvParser")( parserSuite, ) diff --git a/modules/zio/src/test/scala/ceesvee/zio/ZioCsvParserVectorSpec.scala b/modules/zio/src/test/scala/ceesvee/zio/ZioCsvParserVectorSpec.scala new file mode 100644 index 0000000..4e3ed01 --- /dev/null +++ b/modules/zio/src/test/scala/ceesvee/zio/ZioCsvParserVectorSpec.scala @@ -0,0 +1,24 @@ +package ceesvee.zio + +import ceesvee.CsvParser +import zio.stream.ZStream +import zio.test.ZIOSpecDefault + +import java.nio.charset.StandardCharsets + +object ZioCsvParserVectorSpec extends ZIOSpecDefault with 
ceesvee.CsvParserParserSuite { + + override val spec = suite("ZioCsvParserVector")( + parserSuite, + ) + + override protected def parse(lines: Iterable[String], options: CsvParser.Options) = { + val charset = StandardCharsets.UTF_8 + val input = ZStream.fromIterable(lines).intersperse("\n").rechunk(4096).mapConcat(_.getBytes(charset)) + input + .via(ZioCsvParserVector.parse(charset, options)) + .map(_.toList) + .runCollect + .mapError(e => new RuntimeException(s"failed to parse: $e")) + } +} diff --git a/tests/src/test/scala/ceesvee/tests/RealWorldCsvSpec.scala b/tests/src/test/scala/ceesvee/tests/RealWorldCsvSpec.scala index 709569b..64f6b10 100644 --- a/tests/src/test/scala/ceesvee/tests/RealWorldCsvSpec.scala +++ b/tests/src/test/scala/ceesvee/tests/RealWorldCsvSpec.scala @@ -10,6 +10,7 @@ import ceesvee.tests.model.NZGreenhouseGasEmissions import ceesvee.tests.model.UkCausewayCoast import ceesvee.tests.model.UkPropertySalesPricePaid import ceesvee.zio.ZioCsvReader +import ceesvee.zio.ZioCsvReaderVector import zio.ZIO import zio.durationInt import zio.stream.ZPipeline @@ -28,6 +29,7 @@ import java.nio.file.Paths object RealWorldCsvSpec extends ZIOSpecDefault { + private val charset = StandardCharsets.UTF_8 private val options = CsvReader.Options.Defaults override val spec = suite("RealWorldCsv")( @@ -68,7 +70,7 @@ object RealWorldCsvSpec extends ZIOSpecDefault { } }, test("zio") { - val stream = readFileZio(path) + val stream = readFileZio(path).via(ZPipeline.utfDecode) ZIO.scoped[Any] { ZioCsvReader.decodeWithHeader(stream, UkCausewayCoast.csvHeader, options).flatMap { s => s.runCollect.mapError(Left(_)) @@ -77,6 +79,16 @@ object RealWorldCsvSpec extends ZIOSpecDefault { assertResult(result) } }, + test("zio vector") { + val stream = readFileZio(path).drop(3) // UTF8 BOM + ZIO.scoped[Any] { + ceesvee.zio.ZioCsvReaderVector.decodeWithHeader(stream, UkCausewayCoast.csvHeader, charset, options).flatMap { s => + s.runCollect.mapError(Left(_)) + } + }.map { result 
=> + assertResult(result) + } + }, ) }*), suite("NZ greenhouse gas emissions 2019")({ @@ -111,6 +123,14 @@ object RealWorldCsvSpec extends ZIOSpecDefault { val pipeline = ZioCsvReader.decode(options)(decoder, implicitly).mapError { case e: CsvParser.Error.LineTooLong => e }.andThen(ZPipeline.mapZIO(ZIO.fromEither(_))) + readFileZio(path).via(ZPipeline.utfDecode).via(pipeline).runCount.map { count => + assertTrue(count == total) + } + }, + test("zio vector") { + val pipeline = ZioCsvReaderVector.decode(charset, options)(decoder, implicitly).mapError { + case e: CsvParser.Error.LineTooLong => e + }.andThen(ZPipeline.mapZIO(ZIO.fromEither(_))) readFileZio(path).via(pipeline).runCount.map { count => assertTrue(count == total) } @@ -138,7 +158,7 @@ object RealWorldCsvSpec extends ZIOSpecDefault { } }, test("zio") { - val stream = readFileZio(path) + val stream = readFileZio(path).via(ZPipeline.utfDecode) ZIO.scoped[Any] { ZioCsvReader.decodeWithHeader(stream, header, options).flatMap { s => s.collectRight.runCount.mapError(Left(_)) @@ -194,6 +214,6 @@ object RealWorldCsvSpec extends ZIOSpecDefault { } private def readFileZio(path: Path) = { - ZStream.fromPath(path, chunkSize = 16384) >>> ZPipeline.utfDecode + ZStream.fromPath(path, chunkSize = 16384) } }