Skip to content

Commit cb272e7

Browse files
committed
Merge branch 'phase1-shared-buffer-layer-policy'
Integrates buffer layer orchestration deepening and frequency table seam improvements with the architecture consolidation refactor:

- Keep HEAD's streaming integration (streaming.go merged into algorithms)
- Keep phase1's frequency helpers (BuildFrequenciesFromReader, etc.)
- Keep HEAD's byte slice APIs (WriteFrequenciesToBytes, etc.)
- Use phase1's strict frequency validation (ReadFrequenciesExact)
- Add phase1's BuildCumulativeStrict for corrupt table detection
2 parents 2bf38c7 + b028d24 commit cb272e7

24 files changed

Lines changed: 1548 additions & 425 deletions

algorithms/arithmetic/go/arithmetic.go

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -197,15 +197,14 @@ func WriteFrequencies(w io.Writer, freq []uint32) error {
197197
}
198198

199199
// ReadFrequencies deserializes a frequency table from the reader.
200-
// This is an alias for codec.ReadFrequencies for backward compatibility.
200+
// This is an alias for codec.ReadFrequenciesExact for backward compatibility.
201201
func ReadFrequencies(r io.Reader) ([]uint32, error) {
202-
return codec.ReadFrequencies(r, SymbolLimit)
202+
return codec.ReadFrequenciesExact(r, SymbolLimit)
203203
}
204204

205205
// BuildFrequenciesFromFile reads a file and counts byte frequencies.
206206
// The frequencies are scaled to fit within MaxTotal.
207207
func BuildFrequenciesFromFile(path string) ([]uint32, error) {
208-
freq := make([]uint32, SymbolLimit)
209208
f, err := os.Open(path)
210209
if err != nil {
211210
return nil, fmt.Errorf("cannot open input file for reading: %s: %w", path, err)
@@ -220,16 +219,10 @@ func BuildFrequenciesFromFile(path string) ([]uint32, error) {
220219
return nil, fmt.Errorf("input file too large (max %d bytes)", MaxInputSize)
221220
}
222221

223-
r := bufio.NewReader(f)
224-
for {
225-
b, err := r.ReadByte()
226-
if err != nil {
227-
break
228-
}
229-
freq[int(b)]++
222+
freq, err := codec.BuildScaledFrequenciesFromReader(bufio.NewReader(f), MaxTotal)
223+
if err != nil {
224+
return nil, fmt.Errorf("cannot read input file: %s: %w", path, err)
230225
}
231-
freq[EOFSymbol] = 1
232-
ScaleFrequencies(freq)
233226
return freq, nil
234227
}
235228

@@ -243,13 +236,11 @@ func Encode(input io.Reader, w io.Writer) error {
243236
return fmt.Errorf("input too large (max %d bytes)", MaxInputSize)
244237
}
245238

246-
freq := make([]uint32, SymbolLimit)
247-
for _, b := range data {
248-
freq[int(b)]++
239+
freq, err := codec.BuildScaledFrequenciesChecked(data, MaxTotal)
240+
if err != nil {
241+
return fmt.Errorf("failed to count input frequencies: %w", err)
249242
}
250-
freq[EOFSymbol] = 1
251-
ScaleFrequencies(freq)
252-
cumulative := BuildCumulative(freq)
243+
cumulative := codec.BuildCumulative(freq)
253244

254245
if _, err := w.Write([]byte{'A', 'E', 'N', 'C'}); err != nil {
255246
return err
@@ -285,7 +276,10 @@ func Decode(r io.Reader, w io.Writer) error {
285276
if err != nil {
286277
return err
287278
}
288-
cumulative := BuildCumulative(freq)
279+
cumulative, err := codec.BuildCumulativeStrict(freq, "invalid frequency table")
280+
if err != nil {
281+
return err
282+
}
289283

290284
bw := bufio.NewWriter(w)
291285
bitReader := codec.NewBitReader(br)

algorithms/arithmetic/go/arithmetic_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@ package arithmetic
22

33
import (
44
"bytes"
5+
"encoding/binary"
6+
"errors"
57
"os"
68
"path/filepath"
79
"testing"
10+
11+
"github.com/LessUp/compress-kit/algorithms/shared/go/codec"
812
)
913

1014
func TestCompressDecompressRoundTrip(t *testing.T) {
@@ -114,3 +118,30 @@ func TestCompressDecompressAllBytes(t *testing.T) {
114118
t.Fatalf("round-trip mismatch")
115119
}
116120
}
121+
122+
func TestDecodeRejectsAllZeroFrequencyTable(t *testing.T) {
123+
var encoded bytes.Buffer
124+
if _, err := encoded.Write([]byte("AENC")); err != nil {
125+
t.Fatalf("write magic: %v", err)
126+
}
127+
if err := binary.Write(&encoded, binary.LittleEndian, uint32(codec.SymbolLimit)); err != nil {
128+
t.Fatalf("write count: %v", err)
129+
}
130+
for i := 0; i < codec.SymbolLimit; i++ {
131+
if err := binary.Write(&encoded, binary.LittleEndian, uint32(0)); err != nil {
132+
t.Fatalf("write freq[%d]: %v", i, err)
133+
}
134+
}
135+
if _, err := encoded.Write([]byte{0xFF, 0xFF, 0xFF, 0xFF}); err != nil {
136+
t.Fatalf("write trailer: %v", err)
137+
}
138+
139+
var decoded bytes.Buffer
140+
err := Decode(bytes.NewReader(encoded.Bytes()), &decoded)
141+
if err == nil {
142+
t.Fatal("expected decode to reject all-zero frequency table")
143+
}
144+
if !errors.Is(err, codec.ErrCorrupt) {
145+
t.Fatalf("expected corrupt error, got %v", err)
146+
}
147+
}

algorithms/arithmetic/rust/src/lib.rs

Lines changed: 38 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
use std::io;
55

6-
use compresskit_codec::codec::BitWriter;
7-
8-
const SYMBOL_LIMIT: usize = 257;
9-
const EOF_SYMBOL: u32 = (SYMBOL_LIMIT - 1) as u32;
6+
use compresskit_codec::codec::{
7+
build_cumulative, build_cumulative_strict, build_scaled_frequencies, read_frequencies_exact,
8+
write_frequencies, BitWriter, EOF_SYMBOL, SYMBOL_LIMIT,
9+
};
1010
const MAX_TOTAL: u32 = 1 << 24;
1111
const MAX_OUTPUT_SIZE: usize = 1024 * 1024 * 1024; // 1 GiB
1212
const STATE_BITS: u64 = 32;
@@ -15,60 +15,13 @@ const HALF_RANGE: u64 = FULL_RANGE >> 1;
1515
const FIRST_QUARTER: u64 = HALF_RANGE >> 1;
1616
const THIRD_QUARTER: u64 = FIRST_QUARTER * 3;
1717

18-
fn scale_frequencies(freq: &mut [u32]) {
19-
let total: u64 = freq.iter().map(|&f| f as u64).sum();
20-
if total == 0 {
21-
for f in freq.iter_mut() {
22-
*f = 1;
23-
}
24-
return;
25-
}
26-
if total <= MAX_TOTAL as u64 {
27-
return;
28-
}
29-
let mut new_total = 0u64;
30-
for f in freq.iter_mut() {
31-
if *f == 0 {
32-
continue;
33-
}
34-
let mut scaled = (*f as u64 * MAX_TOTAL as u64) / total;
35-
if scaled == 0 {
36-
scaled = 1;
37-
}
38-
*f = scaled as u32;
39-
new_total += scaled;
40-
}
41-
if new_total == 0 {
42-
let base = MAX_TOTAL / freq.len() as u32;
43-
for f in freq.iter_mut() {
44-
*f = if base == 0 { 1 } else { base };
45-
}
46-
}
47-
}
48-
49-
fn build_cumulative(freq: &[u32]) -> Vec<u32> {
50-
let mut cumulative = vec![0u32; freq.len() + 1];
51-
for (i, &f) in freq.iter().enumerate() {
52-
cumulative[i + 1] = cumulative[i] + f;
53-
}
54-
cumulative
55-
}
56-
5718
pub fn encode(input: &[u8]) -> Result<Vec<u8>, io::Error> {
58-
let mut freq = vec![0u32; SYMBOL_LIMIT];
59-
for &b in input {
60-
freq[b as usize] += 1;
61-
}
62-
freq[EOF_SYMBOL as usize] = 1;
63-
scale_frequencies(&mut freq);
19+
let freq = build_scaled_frequencies(input, MAX_TOTAL);
6420
let cumulative = build_cumulative(&freq);
6521

6622
let mut output = Vec::new();
6723
output.extend_from_slice(b"AENC");
68-
output.extend_from_slice(&(SYMBOL_LIMIT as u32).to_le_bytes());
69-
for &f in &freq {
70-
output.extend_from_slice(&f.to_le_bytes());
71-
}
24+
write_frequencies(&mut output, &freq);
7225

7326
let mut low = 0u64;
7427
let mut high = FULL_RANGE - 1;
@@ -160,28 +113,19 @@ pub fn decode(input: &[u8]) -> Result<Vec<u8>, io::Error> {
160113
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid magic"));
161114
}
162115

163-
let count = u32::from_le_bytes([input[4], input[5], input[6], input[7]]) as usize;
164-
if count != SYMBOL_LIMIT {
165-
return Err(io::Error::new(
166-
io::ErrorKind::InvalidData,
167-
"invalid symbol count",
168-
));
169-
}
170-
171-
let mut pos = 8;
172-
let mut freq = vec![0u32; count];
173-
for f in freq.iter_mut() {
174-
if pos + 4 > input.len() {
175-
return Err(io::Error::new(
176-
io::ErrorKind::InvalidData,
177-
"truncated freq table",
178-
));
179-
}
180-
*f = u32::from_le_bytes([input[pos], input[pos + 1], input[pos + 2], input[pos + 3]]);
181-
pos += 4;
182-
}
183-
184-
let cumulative = build_cumulative(&freq);
116+
let mut pos = 4;
117+
let freq = read_frequencies_exact(
118+
input,
119+
&mut pos,
120+
SYMBOL_LIMIT,
121+
"truncated freq table",
122+
"truncated freq table",
123+
"invalid symbol count",
124+
)
125+
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.message))?;
126+
127+
let cumulative = build_cumulative_strict(&freq, "invalid frequency table")
128+
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.message))?;
185129
let total = cumulative[cumulative.len() - 1] as u64;
186130

187131
let mut low = 0u64;
@@ -269,17 +213,17 @@ pub fn decode(input: &[u8]) -> Result<Vec<u8>, io::Error> {
269213

270214
// Streaming adapters
271215
use compresskit_codec::codec::{
272-
io_error_to_codec_error, BufferedDecoder, BufferedEncoder, CodecError, Decoder, Encoder,
216+
io_error_to_codec_error, streaming_decoder, streaming_encoder, CodecError, Decoder, Encoder,
273217
};
274218

275219
/// Creates a new streaming Arithmetic encoder.
276220
pub fn new_encoder() -> impl Encoder {
277-
BufferedEncoder::new(arithmetic_encode)
221+
streaming_encoder(arithmetic_encode)
278222
}
279223

280224
/// Creates a new streaming Arithmetic decoder.
281225
pub fn new_decoder() -> impl Decoder {
282-
BufferedDecoder::new(arithmetic_decode)
226+
streaming_decoder(arithmetic_decode)
283227
}
284228

285229
fn arithmetic_encode(input: &[u8]) -> Result<Vec<u8>, CodecError> {
@@ -301,4 +245,20 @@ mod tests {
301245

302246
assert_eq!(decoded, vec![0x00]);
303247
}
248+
249+
#[test]
250+
fn decode_rejects_all_zero_frequency_table() {
251+
let mut encoded = Vec::new();
252+
encoded.extend_from_slice(b"AENC");
253+
encoded.extend_from_slice(&(257u32).to_le_bytes());
254+
for _ in 0..257 {
255+
encoded.extend_from_slice(&0u32.to_le_bytes());
256+
}
257+
encoded.extend_from_slice(&[0xFF; 4]);
258+
259+
let err = decode(&encoded).unwrap_err();
260+
261+
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
262+
assert_eq!(err.to_string(), "invalid frequency table");
263+
}
304264
}

algorithms/huffman/go/huffman.go

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ func BuildTree(freq []uint32) *Node {
102102

103103
// BuildFrequenciesFromFile reads the file and counts byte frequencies.
104104
func BuildFrequenciesFromFile(path string) ([]uint32, error) {
105-
freq := make([]uint32, SymbolLimit)
106105
f, err := os.Open(path)
107106
if err != nil {
108107
return nil, fmt.Errorf("cannot open input file: %s: %w", path, err)
@@ -117,15 +116,10 @@ func BuildFrequenciesFromFile(path string) ([]uint32, error) {
117116
return nil, fmt.Errorf("input file too large (max %d bytes)", MaxInputSize)
118117
}
119118

120-
r := bufio.NewReader(f)
121-
for {
122-
b, err := r.ReadByte()
123-
if err != nil {
124-
break
125-
}
126-
freq[int(b)]++
119+
freq, err := codec.BuildFrequenciesFromReader(bufio.NewReader(f))
120+
if err != nil {
121+
return nil, fmt.Errorf("cannot read input file: %s: %w", path, err)
127122
}
128-
freq[EOFSymbol] = 1
129123
return freq, nil
130124
}
131125

@@ -136,9 +130,9 @@ func WriteFrequencies(w io.Writer, freq []uint32) error {
136130
}
137131

138132
// ReadFrequencies deserializes a frequency table from the reader.
139-
// This is an alias for codec.ReadFrequencies for backward compatibility.
133+
// This is an alias for codec.ReadFrequenciesExact for backward compatibility.
140134
func ReadFrequencies(r io.Reader) ([]uint32, error) {
141-
return codec.ReadFrequencies(r, SymbolLimit)
135+
return codec.ReadFrequenciesExact(r, SymbolLimit)
142136
}
143137

144138
// BuildCodes generates Huffman codes for each symbol by traversing the tree.
@@ -170,11 +164,10 @@ func Encode(input io.Reader, w io.Writer) error {
170164
return fmt.Errorf("input too large (max %d bytes)", MaxInputSize)
171165
}
172166

173-
freq := make([]uint32, SymbolLimit)
174-
for _, b := range data {
175-
freq[int(b)]++
167+
freq, err := codec.BuildFrequenciesChecked(data)
168+
if err != nil {
169+
return fmt.Errorf("failed to count input frequencies: %w", err)
176170
}
177-
freq[EOFSymbol] = 1
178171

179172
root := BuildTree(freq)
180173
codes := make([]string, SymbolLimit)

0 commit comments

Comments
 (0)