Skip to content

Commit 1dea50e

Browse files
authored
Harden discriminant and proof bounds validation (#331)
1 parent ce12fc3 commit 1dea50e

6 files changed

Lines changed: 203 additions & 21 deletions

File tree

src/bqfc.c

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "bqfc.h"
22

3-
#include <assert.h>
43
#include <limits.h>
54
#include <stdlib.h>
65
#include <string.h>
@@ -120,18 +119,25 @@ int bqfc_decompr(mpz_t out_a, mpz_t out_b, const mpz_t D, const struct qfb_c *c)
120119
return ret;
121120
}
122121

123-
static void bqfc_export(uint8_t *out_str, size_t *offset, size_t size,
122+
static int bqfc_export(uint8_t *out_str, size_t *offset, size_t size,
124123
const mpz_t n)
125124
{
126-
size_t bytes;
125+
size_t bytes = 0;
126+
const size_t bits = (size_t)mpz_sizeinbase(n, 2);
127+
const size_t needed_bytes = (bits + 7) / 8;
128+
129+
if (needed_bytes > size) {
130+
return -1;
131+
}
127132

128-
// mpz_export can overflow out_str if reduction bug but this should never happen
129133
mpz_export(&out_str[*offset], &bytes, -1, 1, 0, 0, n);
130-
if (bytes > size)
131-
gmp_printf("bqfc_export overflow offset %d size %d n %Zd\n", *offset, size, n);
134+
if (bytes > size) {
135+
return -1;
136+
}
132137
if (bytes < size)
133138
memset(&out_str[*offset + bytes], 0, size - bytes);
134139
*offset += size;
140+
return 0;
135141
}
136142

137143
enum BQFC_FLAG_BITS {
@@ -171,20 +177,29 @@ int bqfc_serialize_only(uint8_t *out_str, const struct qfb_c *c, size_t d_bits)
171177
{
172178
size_t offset, g_size;
173179

180+
if (d_bits == 0 || d_bits > BQFC_MAX_D_BITS)
181+
return -1;
174182
d_bits = (d_bits + 31) & ~(size_t)31;
183+
if (d_bits > BQFC_MAX_D_BITS)
184+
return -1;
175185

176186
out_str[0] = (uint8_t)c->b_sign << BQFC_B_SIGN_BIT;
177187
out_str[0] |= (mpz_sgn(c->t) < 0 ? 1 : 0) << BQFC_T_SIGN_BIT;
178188
g_size = (mpz_sizeinbase(c->g, 2) + 7) / 8 - 1;
179-
assert(g_size <= UCHAR_MAX);
189+
if (g_size > UCHAR_MAX)
190+
return -1;
180191
out_str[1] = (uint8_t)g_size;
181192
offset = 2;
182193

183-
bqfc_export(out_str, &offset, d_bits / 16 - g_size, c->a);
184-
bqfc_export(out_str, &offset, d_bits / 32 - g_size, c->t);
194+
if (bqfc_export(out_str, &offset, d_bits / 16 - g_size, c->a))
195+
return -1;
196+
if (bqfc_export(out_str, &offset, d_bits / 32 - g_size, c->t))
197+
return -1;
185198

186-
bqfc_export(out_str, &offset, g_size + 1, c->g);
187-
bqfc_export(out_str, &offset, g_size + 1, c->b0);
199+
if (bqfc_export(out_str, &offset, g_size + 1, c->g))
200+
return -1;
201+
if (bqfc_export(out_str, &offset, g_size + 1, c->b0))
202+
return -1;
188203

189204
return 0;
190205
}
@@ -193,7 +208,11 @@ int bqfc_deserialize_only(struct qfb_c *out_c, const uint8_t *str, size_t d_bits
193208
{
194209
size_t offset, bytes, g_size;
195210

211+
if (d_bits == 0 || d_bits > BQFC_MAX_D_BITS)
212+
return -1;
196213
d_bits = (d_bits + 31) & ~(size_t)31;
214+
if (d_bits > BQFC_MAX_D_BITS)
215+
return -1;
197216

198217
g_size = str[1];
199218
if (g_size >= d_bits / 32)
@@ -225,8 +244,11 @@ int bqfc_deserialize_only(struct qfb_c *out_c, const uint8_t *str, size_t d_bits
225244

226245
int bqfc_get_compr_size(size_t d_bits)
227246
{
247+
if (d_bits == 0 || d_bits > BQFC_MAX_D_BITS)
248+
return -1;
228249
size_t size = (d_bits + 31) / 32 * 3 + 4;
229-
assert(size <= INT_MAX);
250+
if (size > INT_MAX)
251+
return -1;
230252
return (int)size;
231253
}
232254

@@ -235,6 +257,8 @@ int bqfc_serialize(uint8_t *out_str, mpz_t a, mpz_t b, size_t d_bits)
235257
struct qfb_c f_c;
236258
int ret;
237259
int valid_size = bqfc_get_compr_size(d_bits);
260+
if (valid_size <= 0 || valid_size > BQFC_FORM_SIZE)
261+
return -1;
238262

239263
if (!mpz_cmp_ui(b, 1) && mpz_cmp_ui(a, 2) <= 0) {
240264
out_str[0] = !mpz_cmp_ui(a, 2) ? BQFC_IS_GEN : BQFC_IS_1;
@@ -271,6 +295,8 @@ int bqfc_deserialize(mpz_t out_a, mpz_t out_b, const mpz_t D, const uint8_t *str
271295
struct qfb_c f_c;
272296
int ret;
273297

298+
if (d_bits == 0 || d_bits > BQFC_MAX_D_BITS)
299+
return -1;
274300
if (size != BQFC_FORM_SIZE)
275301
return -1;
276302

src/create_discriminant.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ integer CreateDiscriminant(std::vector<uint8_t>& seed, int length = 1024) {
1515
}
1616

1717
// Check 2: Validate upper bound (optional but recommended)
18-
const int MAX_DISCRIMINANT_SIZE_BITS = 16384;
18+
const int MAX_DISCRIMINANT_SIZE_BITS = BQFC_MAX_D_BITS;
1919
if (length > MAX_DISCRIMINANT_SIZE_BITS) {
2020
throw std::invalid_argument(
2121
"discriminant_size_bits exceeds maximum allowed value"
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#include "verifier.h"
2+
3+
#include <gtest/gtest.h>
4+
5+
#include <cstdint>
6+
#include <string>
7+
#include <vector>
8+
9+
namespace {
10+
11+
std::vector<uint8_t> db_hex_to_bytes(const std::string& hex) {
12+
EXPECT_EQ(hex.size() % 2, 0U);
13+
std::vector<uint8_t> out;
14+
out.reserve(hex.size() / 2);
15+
for (size_t i = 0; i < hex.size(); i += 2) {
16+
out.push_back(static_cast<uint8_t>(std::stoul(hex.substr(i, 2), nullptr, 16)));
17+
}
18+
return out;
19+
}
20+
21+
std::vector<uint8_t> db_get_fixture_challenge() {
22+
return db_hex_to_bytes("9104c5b5e45d48f374efa0488fe6a617790e9aecb3c9cddec06809b09f45ce9b");
23+
}
24+
25+
std::vector<uint8_t> db_get_fixture_x() {
26+
return db_hex_to_bytes("08000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000");
27+
}
28+
29+
std::vector<uint8_t> db_get_fixture_proof_blob() {
30+
return db_hex_to_bytes(
31+
"0200553bf0f382fc65a94f20afad5dbce2c1ee8ba3bf93053559ac9960c8fd80ac2222e9b649701a4141a4d8999f0dbfe0c39ea744096598a7528328e5199f0aa30aec8aae8ab5018bf1245329a8272ddff1afbd87ad2eaba1b7fd57bd25edc62e0b010000003f0ffcd0dc307a2aa4678bafba661c77d176ef23afc86e7ea9f4f9eac52b8e1850748019245ecc96547da9b731dc72cded5582a9b0c63e13fd42446c7b28b41d3ded1d0b666d5ddb5b29719e4ebe70969e67e42ddd8591eae60d83dbe619f1250400");
32+
}
33+
34+
} // namespace
35+
36+
TEST(DiscriminantBoundsRegressionTest, VerifyRejectsDiscSizeBitsAboveMaximum) {
37+
std::vector<uint8_t> challenge = db_get_fixture_challenge();
38+
std::vector<uint8_t> x = db_get_fixture_x();
39+
std::vector<uint8_t> proof_blob = db_get_fixture_proof_blob();
40+
const integer D = CreateDiscriminant(challenge, BQFC_MAX_D_BITS);
41+
42+
EXPECT_TRUE(CheckProofOfTimeNWesolowski(
43+
D,
44+
x.data(),
45+
proof_blob.data(),
46+
proof_blob.size(),
47+
129499136,
48+
BQFC_MAX_D_BITS,
49+
0));
50+
51+
EXPECT_FALSE(CheckProofOfTimeNWesolowski(
52+
D,
53+
x.data(),
54+
proof_blob.data(),
55+
proof_blob.size(),
56+
129499136,
57+
static_cast<uint64_t>(BQFC_MAX_D_BITS) + 1,
58+
0));
59+
}
60+
61+
TEST(DiscriminantBoundsRegressionTest, CreateDiscriminantAndVerifyRejectsDiscSizeBitsAboveMaximum) {
62+
std::vector<uint8_t> challenge = db_get_fixture_challenge();
63+
std::vector<uint8_t> x = db_get_fixture_x();
64+
std::vector<uint8_t> proof_blob = db_get_fixture_proof_blob();
65+
66+
EXPECT_FALSE(CreateDiscriminantAndCheckProofOfTimeNWesolowski(
67+
challenge,
68+
static_cast<uint32_t>(BQFC_MAX_D_BITS + 1),
69+
x.data(),
70+
proof_blob.data(),
71+
proof_blob.size(),
72+
129499136,
73+
0));
74+
}
75+
76+
TEST(DiscriminantBoundsRegressionTest, BqfcSerializationRejectsOversizedDiscriminantBits) {
77+
uint8_t serialized[BQFC_FORM_SIZE];
78+
mpz_t a;
79+
mpz_t b;
80+
mpz_init_set_ui(a, 1);
81+
mpz_init_set_ui(b, 1);
82+
83+
EXPECT_EQ(bqfc_serialize(serialized, a, b, static_cast<size_t>(BQFC_MAX_D_BITS) + 1), -1);
84+
85+
mpz_clear(a);
86+
mpz_clear(b);
87+
}
88+
89+
TEST(DiscriminantBoundsRegressionTest, BqfcDeserializationRejectsOversizedDiscriminantBits) {
90+
uint8_t serialized[BQFC_FORM_SIZE] = {0};
91+
mpz_t D;
92+
mpz_t out_a;
93+
mpz_t out_b;
94+
mpz_init_set_si(D, -23);
95+
mpz_init(out_a);
96+
mpz_init(out_b);
97+
98+
EXPECT_EQ(
99+
bqfc_deserialize(out_a, out_b, D, serialized, BQFC_FORM_SIZE, static_cast<size_t>(BQFC_MAX_D_BITS) + 1),
100+
-1);
101+
102+
mpz_clear(D);
103+
mpz_clear(out_a);
104+
mpz_clear(out_b);
105+
}

src/regression_unit_tests.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "checked_cast_test.cpp"
2+
#include "discriminant_bounds_regression_test.cpp"
23
#include "proof_deserialization_regression_test.cpp"
34
#include "prover_slow_regression_test.cpp"
45
#include "two_weso_callback_regression_test.cpp"

src/vdf_client.cpp

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
using boost::asio::ip::tcp;
88

99
std::mutex socket_mutex;
10+
namespace {
11+
constexpr int kIterationHeaderDigits = 2;
12+
constexpr int kMaxIterationDigits = 20;
13+
} // namespace
1014

1115
// Segments are 2^16, 2^18, ..., 2^30
1216
// Best case it'll be able to proof for up to 2^36 due to 64-wesolowski restriction.
@@ -101,23 +105,51 @@ void FinishSession(tcp::socket& sock) {
101105
char ack[5];
102106
memset(ack, 0x00, sizeof(ack));
103107
boost::asio::read(sock, boost::asio::buffer(ack, 3), error);
104-
assert (strncmp(ack, "ACK", 3) == 0);
108+
if (strncmp(ack, "ACK", 3) != 0) {
109+
throw std::runtime_error("Invalid stop ACK");
110+
}
105111
} catch (std::exception& e) {
106112
PrintInfo("Exception in thread: " + to_string(e.what()));
107113
}
108114
}
109115

110116
uint64_t ReadIteration(tcp::socket& sock) {
111117
boost::system::error_code error;
112-
char data[20];
113-
memset(data, 0, sizeof(data));
114-
boost::asio::read(sock, boost::asio::buffer(data, 2), error);
115-
int size = (data[0] - '0') * 10 + (data[1] - '0');
118+
char size_buf[kIterationHeaderDigits];
119+
memset(size_buf, 0, sizeof(size_buf));
120+
boost::asio::read(sock, boost::asio::buffer(size_buf, kIterationHeaderDigits), error);
121+
if (error) {
122+
throw std::runtime_error("Failed to read iteration size header");
123+
}
124+
if (size_buf[0] < '0' || size_buf[0] > '9' || size_buf[1] < '0' || size_buf[1] > '9') {
125+
throw std::runtime_error("Iteration size header must be decimal digits");
126+
}
127+
128+
int size = (size_buf[0] - '0') * 10 + (size_buf[1] - '0');
129+
if (size == 0) {
130+
return 0;
131+
}
132+
if (size > kMaxIterationDigits) {
133+
throw std::runtime_error("Invalid iteration size");
134+
}
135+
136+
char data[kMaxIterationDigits];
116137
memset(data, 0, sizeof(data));
117138
boost::asio::read(sock, boost::asio::buffer(data, size), error);
139+
if (error) {
140+
throw std::runtime_error("Failed to read iteration body");
141+
}
118142
uint64_t iters = 0;
119-
for (int i = 0; i < size; i++)
120-
iters = iters * 10 + data[i] - '0';
143+
for (int i = 0; i < size; i++) {
144+
if (data[i] < '0' || data[i] > '9') {
145+
throw std::runtime_error("Iteration body must be decimal digits");
146+
}
147+
const uint64_t digit = static_cast<uint64_t>(data[i] - '0');
148+
if (iters > (std::numeric_limits<uint64_t>::max() - digit) / 10) {
149+
throw std::runtime_error("Iteration value overflow");
150+
}
151+
iters = iters * 10 + digit;
152+
}
121153
return iters;
122154
}
123155

src/verifier.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,17 @@
1111

1212
const uint8_t DEFAULT_ELEMENT[] = { 0x08 };
1313

14+
inline bool IsDiscSizeBitsInRange(const uint64_t disc_size_bits)
15+
{
16+
return disc_size_bits > 0 && disc_size_bits <= static_cast<uint64_t>(BQFC_MAX_D_BITS);
17+
}
18+
19+
inline bool IsDiscriminantInRange(const integer& D)
20+
{
21+
const int d_bits = D.num_bits();
22+
return d_bits > 0 && static_cast<uint64_t>(d_bits) <= static_cast<uint64_t>(BQFC_MAX_D_BITS);
23+
}
24+
1425
int VerifyWesoSegment(integer &D, form x, form proof, integer &B, uint64_t iters, form &out_y)
1526
{
1627
PulmarkReducer reducer;
@@ -43,7 +54,8 @@ void VerifyWesolowskiProof(integer &D, form x, form y, form proof, uint64_t iter
4354

4455
bool CheckProofOfTimeNWesolowski(integer D, const uint8_t* x_s, const uint8_t* proof_blob, size_t proof_blob_len, uint64_t iterations, uint64 disc_size_bits, uint64_t depth)
4556
{
46-
(void)disc_size_bits;
57+
if (!IsDiscSizeBitsInRange(disc_size_bits) || !IsDiscriminantInRange(D))
58+
return false;
4759
const size_t form_size = BQFC_FORM_SIZE;
4860
const size_t segment_len = 8 + B_bytes + form_size;
4961
const size_t base_len = 2 * form_size;
@@ -125,6 +137,7 @@ bool CheckProofOfTimeNWesolowskiCommon(integer& D, form& x, const uint8_t* proof
125137
}
126138

127139
std::pair<bool, std::vector<uint8_t>> CheckProofOfTimeNWesolowskiWithB(integer D, integer B, const uint8_t* x_s, const uint8_t* proof_blob, size_t proof_blob_len, uint64_t iterations, uint64_t depth) {
140+
if (!IsDiscriminantInRange(D)) return {false, {}};
128141
const size_t form_size = BQFC_FORM_SIZE;
129142
const size_t segment_len = 8 + B_bytes + form_size;
130143
const uint64_t max_depth = static_cast<uint64_t>((std::numeric_limits<size_t>::max() - form_size) / segment_len);
@@ -150,6 +163,7 @@ std::pair<bool, std::vector<uint8_t>> CheckProofOfTimeNWesolowskiWithB(integer D
150163
}
151164

152165
integer GetBFromProof(integer D, const uint8_t* x_s, const uint8_t* proof_blob, size_t proof_blob_len, uint64_t iterations, uint64_t depth) {
166+
if (!IsDiscriminantInRange(D)) throw std::runtime_error("Invalid proof.");
153167
const size_t form_size = BQFC_FORM_SIZE;
154168
const size_t segment_len = 8 + B_bytes + form_size;
155169
const size_t base_len = 2 * form_size;
@@ -170,10 +184,14 @@ integer GetBFromProof(integer D, const uint8_t* x_s, const uint8_t* proof_blob,
170184

171185
bool CreateDiscriminantAndCheckProofOfTimeNWesolowski(std::vector<uint8_t> seed, uint32 disc_size_bits, const uint8_t* x_s, const uint8_t* proof_blob, size_t proof_blob_len, uint64_t iterations, uint64_t depth)
172186
{
187+
if (!IsDiscSizeBitsInRange(disc_size_bits))
188+
return false;
173189
integer D = CreateDiscriminant(
174190
seed,
175191
disc_size_bits
176192
);
193+
if (!IsDiscriminantInRange(D))
194+
return false;
177195

178196
return CheckProofOfTimeNWesolowski(D, x_s, proof_blob, proof_blob_len, iterations, disc_size_bits, depth);
179197
}

0 commit comments

Comments
 (0)