|
1 | 1 | use std::fs::File; |
2 | 2 | use std::io; |
3 | 3 | use std::io::{BufReader, Read, Seek, SeekFrom}; |
| 4 | +use std::ops::AddAssign; |
4 | 5 | use tokio::io::AsyncSeek; |
5 | 6 | use zstd_framed; |
6 | 7 | use zstd_framed::{ZstdReader, ZstdWriter}; |
@@ -149,16 +150,81 @@ pub fn new_zstd_writer<'a, W: io::Write>(inner: W, max_frame_size: Option<u32>) |
149 | 150 | .build() |
150 | 151 | } |
151 | 152 |
|
/// A one-shot compressor: compress `src` to `dst` in one go.
///
/// Like a `FnOnce` closure, but polymorphic over the arguments: `compress` takes `self`
/// by value and accepts any [io::Read] source together with any [io::Write] sink.
pub trait CompressOnce {
    /// Consumes the compressor, compressing the contents of `src` into `dst`.
    ///
    /// On success, returns the [CompressionStats] for the run; otherwise returns the
    /// [io::Error] that interrupted it.
    fn compress(self, src: impl io::Read, dst: impl io::Write) -> io::Result<CompressionStats>;
}
| 159 | + |
/// Implements [CompressOnce] for the zstd compression algorithm.
///
/// The only configuration is [Zstd::max_frame_size], which is passed through to
/// [compress_with_zstd] unchanged.
pub struct Zstd {
    /// If `Some`, add a seek table with `max_frame_size` to the compressed output.
    ///
    /// NOTE(review): `None` presumably means no seek table is written; confirm against
    /// `new_zstd_writer`.
    ///
    /// See [zstd_framed::writer::ZstdWriterBuilder::with_seek_table].
    pub max_frame_size: Option<u32>,
}
| 167 | + |
| 168 | +impl CompressOnce for Zstd { |
| 169 | + fn compress(self, src: impl io::Read, dst: impl io::Write) -> io::Result<CompressionStats> { |
| 170 | + compress_with_zstd(src, dst, self.max_frame_size) |
| 171 | + } |
| 172 | +} |
| 173 | + |
/// Byte counters describing one or more compression runs.
///
/// Values can be accumulated with `+=`, which sums each counter independently.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CompressionStats {
    /// Number of bytes read from the source.
    pub bytes_read: u64,
    /// Number of bytes written to the destination.
    pub bytes_written: u64,
}

impl AddAssign for CompressionStats {
    /// Adds the counters of `rhs` onto `self`, field by field.
    fn add_assign(&mut self, rhs: Self) {
        // Destructure so that a newly added field cannot be forgotten here without a
        // compile error.
        let Self {
            bytes_read,
            bytes_written,
        } = rhs;
        self.bytes_read += bytes_read;
        self.bytes_written += bytes_written;
    }
}
| 186 | + |
152 | 187 | pub fn compress_with_zstd<W: io::Write, R: io::Read>( |
153 | 188 | mut src: R, |
154 | | - mut dst: W, |
| 189 | + dst: W, |
155 | 190 | max_frame_size: Option<u32>, |
156 | | -) -> io::Result<()> { |
157 | | - let mut writer = new_zstd_writer(&mut dst, max_frame_size)?; |
158 | | - io::copy(&mut src, &mut writer)?; |
159 | | - writer.shutdown()?; |
160 | | - drop(writer); |
161 | | - Ok(()) |
| 191 | +) -> io::Result<CompressionStats> { |
| 192 | + /// [io::Write] wrapper that counts how many bytes were written. |
| 193 | + struct Writer<W> { |
| 194 | + bytes_written: u64, |
| 195 | + inner: W, |
| 196 | + } |
| 197 | + |
| 198 | + impl<W: io::Write> io::Write for Writer<W> { |
| 199 | + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
| 200 | + let n = self.inner.write(buf)?; |
| 201 | + self.bytes_written += n as u64; |
| 202 | + Ok(n) |
| 203 | + } |
| 204 | + |
| 205 | + fn flush(&mut self) -> io::Result<()> { |
| 206 | + self.inner.flush() |
| 207 | + } |
| 208 | + } |
| 209 | + |
| 210 | + // Wrap `dst` in [Writer], and use it as the sink for the zstd writer, |
| 211 | + // such that we can determine how many (compressed) bytes came out at the end. |
| 212 | + let mut dst = Writer { |
| 213 | + bytes_written: 0, |
| 214 | + inner: dst, |
| 215 | + }; |
| 216 | + let mut zstd_writer = new_zstd_writer(&mut dst, max_frame_size)?; |
| 217 | + |
| 218 | + let bytes_read = io::copy(&mut src, &mut zstd_writer)?; |
| 219 | + zstd_writer.shutdown()?; |
| 220 | + drop(zstd_writer); |
| 221 | + |
| 222 | + let stats = CompressionStats { |
| 223 | + bytes_read, |
| 224 | + bytes_written: dst.bytes_written, |
| 225 | + }; |
| 226 | + |
| 227 | + Ok(stats) |
162 | 228 | } |
163 | 229 |
|
164 | 230 | pub use async_impls::AsyncCompressReader; |
|
0 commit comments