Skip to content

Commit 86f6c4c

Browse files
authored
Merge pull request #8 from LessUp/copilot/cpp-frequency-hardening-20260526
fix: harden cpp frequency table handling
2 parents 455de60 + a289a24 commit 86f6c4c

6 files changed

Lines changed: 232 additions & 40 deletions

File tree

Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
build: build-huffman build-arithmetic build-range build-rle
1010

1111
build-huffman:
12-
g++ -std=c++17 -O2 -Wall -Wextra -Werror -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/huffman/cpp/main.cpp -o algorithms/huffman/cpp/huffman_cpp
12+
g++ -std=c++17 -O2 -Wall -Wextra -Werror -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/shared/cpp/src/frequency_table.cpp algorithms/huffman/cpp/main.cpp -o algorithms/huffman/cpp/huffman_cpp
1313
go build -o algorithms/huffman/go/huffman_go ./algorithms/huffman/go/cmd
1414
cargo build --manifest-path algorithms/huffman/rust/Cargo.toml --bin huffman_rust --release
1515
cp algorithms/huffman/rust/target/release/huffman_rust algorithms/huffman/rust/huffman_rust
1616

1717
build-arithmetic:
18-
g++ -std=c++17 -O2 -Wall -Wextra -Werror -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/arithmetic/cpp/main.cpp -o algorithms/arithmetic/cpp/arithmetic_cpp
18+
g++ -std=c++17 -O2 -Wall -Wextra -Werror -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/shared/cpp/src/frequency_table.cpp algorithms/arithmetic/cpp/main.cpp -o algorithms/arithmetic/cpp/arithmetic_cpp
1919
go build -o algorithms/arithmetic/go/arithmetic_go ./algorithms/arithmetic/go/cmd
2020
cargo build --manifest-path algorithms/arithmetic/rust/Cargo.toml --bin arithmetic_rust --release
2121
cp algorithms/arithmetic/rust/target/release/arithmetic_rust algorithms/arithmetic/rust/arithmetic_rust
@@ -40,7 +40,7 @@ test: test-data \
4040
test-conformance test-cli-smoke
4141

4242
test-shared-cpp:
43-
g++ -std=c++17 -O2 -Wall -Wextra -Werror -DCOMPRESSKIT_NO_MAIN -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/huffman/cpp/main.cpp algorithms/arithmetic/cpp/main.cpp algorithms/range/cpp/main.cpp algorithms/rle/cpp/main.cpp algorithms/shared/cpp/tests/test_lifecycle.cpp -o algorithms/shared/cpp/tests/test_lifecycle
43+
g++ -std=c++17 -O2 -Wall -Wextra -Werror -DCOMPRESSKIT_NO_MAIN -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/shared/cpp/src/frequency_table.cpp algorithms/huffman/cpp/main.cpp algorithms/arithmetic/cpp/main.cpp algorithms/range/cpp/main.cpp algorithms/rle/cpp/main.cpp algorithms/shared/cpp/tests/test_lifecycle.cpp -o algorithms/shared/cpp/tests/test_lifecycle
4444
./algorithms/shared/cpp/tests/test_lifecycle
4545

4646
test-shared-go:

algorithms/arithmetic/cpp/main.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <vector>
66

77
#include "compresskit/buffer_api.hpp"
8+
#include "compresskit/frequency_table.hpp"
89

910
class BitWriter {
1011
public:
@@ -242,10 +243,17 @@ static std::vector<uint32_t> build_frequencies_from_file(const std::string& inpu
242243
if (!in) {
243244
return freq;
244245
}
245-
char c;
246-
while (in.get(c)) {
247-
unsigned char uc = static_cast<unsigned char>(c);
248-
freq[static_cast<uint32_t>(uc)]++;
246+
uint32_t overflow_symbol = 0;
247+
const auto status = compresskit::accumulate_frequencies(in, freq, &overflow_symbol);
248+
if (status == compresskit::FrequencyCountStatus::IO_ERROR) {
249+
std::cerr << "Failed to read input file\n";
250+
freq.clear();
251+
return freq;
252+
}
253+
if (status == compresskit::FrequencyCountStatus::OVERFLOW) {
254+
std::cerr << "Frequency overflow for symbol " << overflow_symbol << "\n";
255+
freq.clear();
256+
return freq;
249257
}
250258
freq[EOF_SYMBOL] = 1;
251259
scale_frequencies(freq);
@@ -266,30 +274,20 @@ static std::vector<uint32_t> build_cumulative(const std::vector<uint32_t>& freq)
266274
}
267275

268276
static void write_frequencies(std::ostream& out, const std::vector<uint32_t>& freq) {
269-
uint32_t count = static_cast<uint32_t>(freq.size());
270-
out.write(reinterpret_cast<const char*>(&count), sizeof(count));
271-
for (uint32_t v : freq) {
272-
out.write(reinterpret_cast<const char*>(&v), sizeof(v));
273-
}
277+
compresskit::write_frequency_table(out, freq);
274278
}
275279

276280
static bool read_frequencies(std::istream& in, std::vector<uint32_t>& freq) {
277281
uint32_t count = 0;
278-
in.read(reinterpret_cast<char*>(&count), sizeof(count));
279-
if (!in) {
282+
const auto status = compresskit::read_frequency_table(in, freq, SYMBOL_LIMIT, &count);
283+
if (status == compresskit::FrequencyTableReadStatus::TRUNCATED) {
280284
std::cerr << "Failed to read frequency table\n";
281285
return false;
282286
}
283-
if (count != SYMBOL_LIMIT) {
287+
if (status == compresskit::FrequencyTableReadStatus::BAD_COUNT) {
284288
std::cerr << "Bad frequency table size: " << count << "\n";
285289
return false;
286290
}
287-
freq.assign(count, 0);
288-
in.read(reinterpret_cast<char*>(freq.data()), freq.size() * sizeof(uint32_t));
289-
if (!in) {
290-
std::cerr << "Failed to read frequency table\n";
291-
return false;
292-
}
293291
return true;
294292
}
295293

@@ -305,8 +303,10 @@ static bool compress_file(const std::string& input_path, const std::string& outp
305303
}
306304
}
307305
}
308-
309306
std::vector<uint32_t> freq = build_frequencies_from_file(input_path);
307+
if (freq.empty()) {
308+
return false;
309+
}
310310
std::vector<uint32_t> cumulative = build_cumulative(freq);
311311

312312
std::ifstream in(input_path, std::ios::binary);

algorithms/huffman/cpp/main.cpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <vector>
88

99
#include "compresskit/buffer_api.hpp"
10+
#include "compresskit/frequency_table.hpp"
1011

1112
class BitWriter {
1213
public:
@@ -179,40 +180,37 @@ static std::vector<uint32_t> build_frequencies_from_file(const std::string& inpu
179180
if (!in) {
180181
return freq;
181182
}
182-
char c;
183-
while (in.get(c)) {
184-
unsigned char uc = static_cast<unsigned char>(c);
185-
freq[static_cast<uint32_t>(uc)]++;
183+
uint32_t overflow_symbol = 0;
184+
const auto status = compresskit::accumulate_frequencies(in, freq, &overflow_symbol);
185+
if (status == compresskit::FrequencyCountStatus::IO_ERROR) {
186+
std::cerr << "Failed to read input file\n";
187+
freq.clear();
188+
return freq;
189+
}
190+
if (status == compresskit::FrequencyCountStatus::OVERFLOW) {
191+
std::cerr << "Frequency overflow for symbol " << overflow_symbol << "\n";
192+
freq.clear();
193+
return freq;
186194
}
187195
freq[EOF_SYMBOL] = 1;
188196
return freq;
189197
}
190198

191199
static void write_frequencies(std::ostream& out, const std::vector<uint32_t>& freq) {
192-
uint32_t count = static_cast<uint32_t>(freq.size());
193-
out.write(reinterpret_cast<const char*>(&count), sizeof(count));
194-
for (uint32_t v : freq) {
195-
out.write(reinterpret_cast<const char*>(&v), sizeof(v));
196-
}
200+
compresskit::write_frequency_table(out, freq);
197201
}
198202

199203
static bool read_frequencies(std::istream& in, std::vector<uint32_t>& freq) {
200204
uint32_t count = 0;
201-
in.read(reinterpret_cast<char*>(&count), sizeof(count));
202-
if (!in) {
205+
const auto status = compresskit::read_frequency_table(in, freq, SYMBOL_LIMIT, &count);
206+
if (status == compresskit::FrequencyTableReadStatus::TRUNCATED) {
203207
std::cerr << "Failed to read frequency table\n";
204208
return false;
205209
}
206-
if (count != SYMBOL_LIMIT) {
210+
if (status == compresskit::FrequencyTableReadStatus::BAD_COUNT) {
207211
std::cerr << "Bad frequency table size: " << count << "\n";
208212
return false;
209213
}
210-
freq.assign(count, 0);
211-
in.read(reinterpret_cast<char*>(freq.data()), freq.size() * sizeof(uint32_t));
212-
if (!in) {
213-
std::cerr << "Failed to read frequency table\n";
214-
return false;
215-
}
216214
return true;
217215
}
218216

@@ -230,6 +228,9 @@ static bool compress_file(const std::string& input_path, const std::string& outp
230228
}
231229

232230
std::vector<uint32_t> freq = build_frequencies_from_file(input_path);
231+
if (freq.empty()) {
232+
return false;
233+
}
233234
UniqueNode root(build_tree(freq)); // RAII: automatic cleanup
234235
std::vector<std::string> codes(SYMBOL_LIMIT);
235236
std::string prefix;
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <istream>
5+
#include <ostream>
6+
#include <vector>
7+
8+
namespace compresskit {
9+
10+
enum class FrequencyTableReadStatus {
11+
OK = 0,
12+
TRUNCATED,
13+
BAD_COUNT,
14+
};
15+
16+
enum class FrequencyCountStatus {
17+
OK = 0,
18+
IO_ERROR,
19+
OVERFLOW,
20+
};
21+
22+
bool write_frequency_table(std::ostream& out, const std::vector<uint32_t>& freq);
23+
24+
FrequencyTableReadStatus read_frequency_table(std::istream& in, std::vector<uint32_t>& freq,
25+
uint32_t expected_count,
26+
uint32_t* actual_count = nullptr);
27+
28+
FrequencyCountStatus accumulate_frequencies(std::istream& in, std::vector<uint32_t>& freq,
29+
uint32_t* overflow_symbol = nullptr);
30+
31+
} // namespace compresskit
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#include "compresskit/frequency_table.hpp"
2+
3+
#include <array>
4+
#include <limits>
5+
6+
namespace compresskit {
7+
namespace {
8+
9+
// Emits `value` to `out` as four bytes, least-significant byte first, so the
// on-disk layout does not depend on the host byte order.
// Returns true when the stream is still in a good state after the write.
bool write_u32_le(std::ostream& out, uint32_t value) {
    char encoded[4];
    for (unsigned idx = 0; idx < 4; ++idx) {
        // Byte `idx` carries bits [8*idx, 8*idx + 7] of the value.
        encoded[idx] = static_cast<char>((value >> (8u * idx)) & 0xFFu);
    }
    out.write(encoded, static_cast<std::streamsize>(sizeof(encoded)));
    return !out.fail();
}
19+
20+
// Reads four bytes from `in` and decodes them as a little-endian 32-bit value.
// Returns false (leaving `value` untouched) when the stream cannot supply all
// four bytes; otherwise stores the decoded number in `value` and returns true.
bool read_u32_le(std::istream& in, uint32_t& value) {
    unsigned char raw[4] = {0, 0, 0, 0};
    in.read(reinterpret_cast<char*>(raw), static_cast<std::streamsize>(sizeof(raw)));
    if (in.fail()) {
        return false;
    }
    // Fold from the most significant byte (last on the wire) down to the least.
    uint32_t decoded = 0;
    for (int idx = 3; idx >= 0; --idx) {
        decoded = (decoded << 8) | static_cast<uint32_t>(raw[idx]);
    }
    value = decoded;
    return true;
}
30+
31+
} // namespace
32+
33+
bool write_frequency_table(std::ostream& out, const std::vector<uint32_t>& freq) {
34+
if (!write_u32_le(out, static_cast<uint32_t>(freq.size()))) {
35+
return false;
36+
}
37+
for (uint32_t value : freq) {
38+
if (!write_u32_le(out, value)) {
39+
return false;
40+
}
41+
}
42+
return true;
43+
}
44+
45+
// Parses a frequency table written as a little-endian u32 entry count followed
// by that many little-endian u32 values.
//
// in             - stream positioned at the start of the table.
// freq           - receives the decoded values; left empty on any failure.
// expected_count - required entry count; 0 disables the check.
// actual_count   - optional; receives the count read from the stream whenever
//                  the count field itself could be read.
//
// Returns TRUNCATED when the stream ends before the count or before all
// entries were read, BAD_COUNT when a non-zero expected_count does not match
// the stored count, and OK otherwise.
FrequencyTableReadStatus read_frequency_table(std::istream& in, std::vector<uint32_t>& freq,
                                              uint32_t expected_count, uint32_t* actual_count) {
    freq.clear();

    uint32_t count = 0;
    if (!read_u32_le(in, count)) {
        return FrequencyTableReadStatus::TRUNCATED;
    }
    if (actual_count != nullptr) {
        *actual_count = count;
    }
    if (expected_count != 0 && count != expected_count) {
        return FrequencyTableReadStatus::BAD_COUNT;
    }

    // The count comes from the input and is unchecked when expected_count is 0,
    // so never allocate the full table before the payload is known to exist:
    // cap the up-front reservation and let the vector grow as entries arrive.
    constexpr uint32_t kMaxUpfrontReserve = 64u * 1024u;
    freq.reserve(count < kMaxUpfrontReserve ? count : kMaxUpfrontReserve);

    for (uint32_t i = 0; i < count; ++i) {
        uint32_t value = 0;
        if (!read_u32_le(in, value)) {
            freq.clear();
            return FrequencyTableReadStatus::TRUNCATED;
        }
        freq.push_back(value);
    }
    return FrequencyTableReadStatus::OK;
}
69+
70+
// Adds one to freq[b] for every byte b read from `in` until end of stream.
//
// in              - stream to consume.
// freq            - running counts indexed by byte value; existing counts are
//                   kept and extended. Grown to 256 entries if it is shorter,
//                   so every byte value has a slot.
// overflow_symbol - optional; receives the byte value whose counter was
//                   already at the uint32_t maximum when another occurrence
//                   was seen.
//
// Returns OK at end of stream, OVERFLOW when a counter would wrap (that
// counter is left at its maximum), and IO_ERROR when the stream fails before
// reaching end of stream.
FrequencyCountStatus accumulate_frequencies(std::istream& in, std::vector<uint32_t>& freq,
                                            uint32_t* overflow_symbol) {
    // Every byte value 0..255 is used as an index below; make sure the slots
    // exist instead of indexing past the end of a shorter vector.
    constexpr std::size_t kByteValues = 256;
    if (freq.size() < kByteValues) {
        freq.resize(kByteValues, 0);
    }

    std::array<unsigned char, 32 * 1024> buffer{};
    for (;;) {
        in.read(reinterpret_cast<char*>(buffer.data()),
                static_cast<std::streamsize>(buffer.size()));
        // A short final read still delivers bytes, so count whatever arrived
        // before looking at the stream state.
        const std::size_t read_count = static_cast<std::size_t>(in.gcount());
        for (std::size_t i = 0; i < read_count; ++i) {
            const uint32_t symbol = static_cast<uint32_t>(buffer[i]);
            if (freq[symbol] == std::numeric_limits<uint32_t>::max()) {
                if (overflow_symbol != nullptr) {
                    *overflow_symbol = symbol;
                }
                return FrequencyCountStatus::OVERFLOW;
            }
            ++freq[symbol];
        }
        if (in.eof()) {
            return FrequencyCountStatus::OK;
        }
        if (!in) {
            return FrequencyCountStatus::IO_ERROR;
        }
    }
}
94+
95+
} // namespace compresskit

algorithms/shared/cpp/tests/test_lifecycle.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#include <algorithm>
22
#include <cassert>
33
#include <cstdint>
4+
#include <sstream>
45
#include <string>
56
#include <vector>
67

78
#include "compresskit/algorithms.hpp"
9+
#include "compresskit/frequency_table.hpp"
810

911
namespace {
1012

@@ -129,6 +131,65 @@ void test_decode_buffer_preserves_finish_retry_prefix() {
129131
assert(std::string(decoded.value.begin(), decoded.value.end()) == "uvwxyz");
130132
}
131133

134+
void test_write_frequency_table_uses_little_endian_layout() {
135+
std::ostringstream out(std::ios::binary);
136+
const std::vector<uint32_t> freq = {0x78563412u, 0x01020304u};
137+
138+
const bool ok = compresskit::write_frequency_table(out, freq);
139+
assert(ok);
140+
141+
const std::string bytes = out.str();
142+
const std::string expected(
143+
"\x02\x00\x00\x00"
144+
"\x12\x34\x56\x78"
145+
"\x04\x03\x02\x01",
146+
12);
147+
assert(bytes == expected);
148+
}
149+
150+
void test_read_frequency_table_decodes_little_endian_values() {
151+
const std::string bytes(
152+
"\x02\x00\x00\x00"
153+
"\x12\x34\x56\x78"
154+
"\x04\x03\x02\x01",
155+
12);
156+
std::istringstream in(bytes, std::ios::binary);
157+
std::vector<uint32_t> freq;
158+
uint32_t actual_count = 0;
159+
160+
const auto status = compresskit::read_frequency_table(in, freq, 2, &actual_count);
161+
162+
assert(status == compresskit::FrequencyTableReadStatus::OK);
163+
assert(actual_count == 2);
164+
assert((freq == std::vector<uint32_t>{0x78563412u, 0x01020304u}));
165+
}
166+
167+
void test_read_frequency_table_reports_bad_count() {
168+
const std::string bytes("\x02\x00\x00\x00", 4);
169+
std::istringstream in(bytes, std::ios::binary);
170+
std::vector<uint32_t> freq;
171+
uint32_t actual_count = 0;
172+
173+
const auto status = compresskit::read_frequency_table(in, freq, 3, &actual_count);
174+
175+
assert(status == compresskit::FrequencyTableReadStatus::BAD_COUNT);
176+
assert(actual_count == 2);
177+
assert(freq.empty());
178+
}
179+
180+
void test_accumulate_frequencies_reports_overflow() {
181+
std::vector<uint32_t> freq(257, 0);
182+
freq[0] = UINT32_MAX;
183+
std::istringstream in(std::string(1, '\0'), std::ios::binary);
184+
uint32_t overflow_symbol = UINT32_MAX;
185+
186+
const auto status = compresskit::accumulate_frequencies(in, freq, &overflow_symbol);
187+
188+
assert(status == compresskit::FrequencyCountStatus::OVERFLOW);
189+
assert(overflow_symbol == 0);
190+
assert(freq[0] == UINT32_MAX);
191+
}
192+
132193
} // namespace
133194

134195
int main() {
@@ -145,6 +206,10 @@ int main() {
145206

146207
test_encode_buffer_preserves_finish_retry_prefix();
147208
test_decode_buffer_preserves_finish_retry_prefix();
209+
test_write_frequency_table_uses_little_endian_layout();
210+
test_read_frequency_table_decodes_little_endian_values();
211+
test_read_frequency_table_reports_bad_count();
212+
test_accumulate_frequencies_reports_overflow();
148213

149214
return 0;
150215
}

0 commit comments

Comments
 (0)