Skip to content

Commit 26cfa36

Browse files
authored
ML-DSA: Missing Private Key Validation Checks (#2874)
### Issue: `EVP_PKEY_pqdsa_new_raw_private_key()` accepts malformed keys with secret vectors `s1` and `s2` containing coefficients outside the valid range `[-η, η]`. These keys lead to undefined behavior, like producing signatures that fail verification. ### Description of changes: Adds the missing validation checks to `ml_dsa_pack_pk_from_sk()` in `crypto/fipsmodule/ml_dsa/ml_dsa_ref/packing.c`. It now rejects keys if `s1` or `s2` have coefficients exceeding `[-η, η]`. ### Call-outs: - With the addition of these validation checks, we should reject *all* invalid private keys. - Discovered via Wycheproof test vector: https://github.com/C2SP/wycheproof/blob/e3c37e9db0f85a762dfcef1642b046bd31090ca4/testvectors_v1/mldsa_44_sign_noseed_test.json#L626-L646 - **Upstream considerations**: While this change should ideally be made upstream in mldsa-native, we are landing this now since this code is in production and mldsa-native will take time to land. I will open an upstream PR soon to ensure consistency. - **Import protection**: If these checks get overridden during a future upstream import, the tests added in this PR will fail in CI, preventing that merge and ensuring the validation remains in place. ### Testing: - Adds test vector generation script `crypto/fipsmodule/ml_dsa/make_corrupted_key_tests.cc` - Adds the generated test vectors `crypto/evp_extra/mldsa_corrupted_key_tests.txt` - Adds a test `crypto/evp_extra/mldsa_test.cc` that uses these test vectors To run the test: ``` $ cd build $ ./crypto/crypto_test --gtest_filter="*MLDSATest.ExpandedKeyValidation*" ``` To (re-)generate the test vectors: ``` $ cd crypto/fipsmodule/ml_dsa $ make generate ``` By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license and the ISC license.
1 parent ca4a1ba commit 26cfa36

9 files changed

Lines changed: 2632 additions & 0 deletions

File tree

crypto/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,7 @@ if(BUILD_TESTING)
817817
evp_extra/p_pqdsa_test.cc
818818
evp_extra/p_kem_test.cc
819819
evp_extra/scrypt_test.cc
820+
evp_extra/mldsa_test.cc
820821
fips_callback_test.cc
821822
fipsmodule/aes/aes_test.cc
822823
fipsmodule/bn/bn_test.cc

crypto/evp_extra/mldsa_corrupted_key_tests.txt

Lines changed: 2310 additions & 0 deletions
Large diffs are not rendered by default.

crypto/evp_extra/mldsa_test.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#include <gtest/gtest.h>
2+
#include <openssl/evp.h>
3+
#include <openssl/obj.h>
4+
5+
#include "../test/file_test.h"
6+
7+
// ML-DSA parameter sets
8+
struct MLDSAParamSet {
9+
const char name[20];
10+
const int nid;
11+
};
12+
13+
static const struct MLDSAParamSet kMLDSAs[] = {{"MLDSA44", NID_MLDSA44},
14+
{"MLDSA65", NID_MLDSA65},
15+
{"MLDSA87", NID_MLDSA87}};
16+
17+
class MLDSATest : public testing::TestWithParam<MLDSAParamSet> {};
18+
19+
INSTANTIATE_TEST_SUITE_P(All, MLDSATest, testing::ValuesIn(kMLDSAs),
20+
[](const testing::TestParamInfo<MLDSAParamSet> &params)
21+
-> std::string { return params.param.name; });
22+
23+
TEST_P(MLDSATest, ExpandedKeyValidation) {
24+
const MLDSAParamSet ps = GetParam();
25+
26+
// This test verifies that we reject invalid extended keys, because they can
27+
// cause undefined behavior including producing unverifiable signatures.
28+
//
29+
// Test vectors are generated by make_corrupted_key_tests.cc which uses
30+
// internal ML-DSA functions to corrupt keys in specific ways.
31+
32+
FileTestGTest("crypto/evp_extra/mldsa_corrupted_key_tests.txt",
33+
[&](FileTest *t) {
34+
std::string param_set;
35+
ASSERT_TRUE(t->GetInstruction(&param_set, "ParamSet"));
36+
37+
// Skip test vectors for other parameter sets
38+
if (param_set != ps.name) {
39+
t->SkipCurrent();
40+
return;
41+
}
42+
43+
std::vector<uint8_t> corrupted_key;
44+
ASSERT_TRUE(t->GetBytes(&corrupted_key, "CorruptedKey"));
45+
46+
// Try to import the corrupted key - it should fail
47+
bssl::UniquePtr<EVP_PKEY> corrupted_pkey(
48+
EVP_PKEY_pqdsa_new_raw_private_key(
49+
ps.nid, corrupted_key.data(), corrupted_key.size()));
50+
51+
EXPECT_FALSE(corrupted_pkey.get())
52+
<< "Imported corrupted " << ps.name << " key";
53+
});
54+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
make_corrupted_key_tests

crypto/fipsmodule/ml_dsa/Makefile

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Makefile for generating ML-DSA test vectors
2+
#
3+
# Usage:
4+
# make # Build the generator
5+
# make generate # Generate test vectors
6+
# make clean # Clean build artifacts
7+
8+
# Paths relative to this directory
9+
BUILD_DIR = ../../../build
10+
CRYPTO_LIB = $(BUILD_DIR)/crypto/libcrypto.a
11+
OUTPUT_FILE = ../../evp_extra/mldsa_corrupted_key_tests.txt
12+
13+
# Compiler settings
14+
CXX = c++
15+
CXXFLAGS = -std=c++11 -Wall -I../../../include -I../../.. -I.
16+
LDFLAGS = $(CRYPTO_LIB)
17+
18+
.PHONY: all generate clean
19+
20+
all: make_corrupted_key_tests
21+
22+
make_corrupted_key_tests: make_corrupted_key_tests.cc
23+
@if [ ! -f "$(CRYPTO_LIB)" ]; then \
24+
echo "Error: libcrypto.a not found at $(CRYPTO_LIB)"; \
25+
echo "Please build aws-lc first in ./build"; \
26+
exit 1; \
27+
fi
28+
$(CXX) $(CXXFLAGS) -o $@ $< $(LDFLAGS)
29+
30+
generate: make_corrupted_key_tests
31+
@echo "Generating test vectors to $(OUTPUT_FILE)..."
32+
./make_corrupted_key_tests > $(OUTPUT_FILE)
33+
@echo "Successfully generated $(OUTPUT_FILE)"
34+
35+
clean:
36+
rm make_corrupted_key_tests
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
// Generates test vectors with intentionally corrupted ML-DSA private keys
2+
//
3+
// USAGE:
4+
// cd crypto/fipsmodule/ml_dsa
5+
// make generate
6+
//
7+
// This regenerates crypto/evp_extra/mldsa_corrupted_key_tests.txt
8+
9+
#include <openssl/evp.h>
10+
#include <openssl/obj.h>
11+
12+
#include <cassert>
13+
#include <cstddef>
14+
#include <cstdint>
15+
#include <cstdio>
16+
#include <cstring>
17+
#include <functional>
18+
#include <iostream>
19+
#include <vector>
20+
21+
// Need ML-DSA internal headers to manipulate the expanded private key
22+
extern "C" {
23+
#include "./ml_dsa_ref/packing.h"
24+
#include "./ml_dsa_ref/params.h"
25+
#include "./ml_dsa_ref/polyvec.h"
26+
}
27+
28+
static void PrintHex(const std::vector<uint8_t> &data) {
29+
for (uint8_t byte : data) {
30+
printf("%02x", byte);
31+
}
32+
}
33+
34+
struct MLDSAParamSet {
35+
const char name[20];
36+
const int nid;
37+
};
38+
39+
static const struct MLDSAParamSet kMLDSAs[] = {{"MLDSA44", NID_MLDSA44},
40+
{"MLDSA65", NID_MLDSA65},
41+
{"MLDSA87", NID_MLDSA87}};
42+
43+
44+
// Corruption function type: takes unpacked key components and corrupts them
45+
using CorruptionFn =
46+
std::function<void(polyvecl *s1, polyveck *s2, polyveck *t0)>;
47+
48+
49+
// Generates a corrupted private key with the provided corruption
50+
static bool GenerateCorruptedKey(const MLDSAParamSet &ps, ml_dsa_params *params,
51+
const std::vector<uint8_t> &honest_key_bytes,
52+
const CorruptionFn &corruption) {
53+
polyvecl s1;
54+
polyveck s2, t0;
55+
uint8_t rho[ML_DSA_SEEDBYTES];
56+
uint8_t tr[ML_DSA_TRBYTES];
57+
uint8_t key[ML_DSA_SEEDBYTES];
58+
59+
// Unpack the honest key
60+
ml_dsa_unpack_sk(params, rho, tr, key, &t0, &s1, &s2,
61+
honest_key_bytes.data());
62+
63+
// Apply the corruption
64+
corruption(&s1, &s2, &t0);
65+
66+
// Repack the corrupted key
67+
std::vector<uint8_t> corrupted_key(params->secret_key_bytes);
68+
ml_dsa_pack_sk(params, corrupted_key.data(), rho, tr, key, &t0, &s1, &s2);
69+
70+
// Verify the corrupted key differs from the honest key
71+
assert(std::memcmp(corrupted_key.data(), honest_key_bytes.data(),
72+
params->secret_key_bytes) != 0);
73+
74+
// Output the test vector
75+
printf("# corrupted private key with invalid s1 or s2, inconsistent\n");
76+
printf("CorruptedKey = ");
77+
PrintHex(corrupted_key);
78+
printf("\n\n");
79+
80+
// Create a consistent version by recomputing the public key and tr
81+
std::vector<uint8_t> consistent_key = corrupted_key;
82+
83+
// Recompute the public key. We cannot use ml_dsa_pack_pk_from_sk since we
84+
// fixed it to fail for invalid secret keys. Instead we adapt from
85+
// https://github.com/aws/aws-lc/blob/0336dd78a0f2623c1f9b209a98cd497026d9c779/crypto/fipsmodule/ml_dsa/ml_dsa_ref/packing.c#L7-L61
86+
ml_dsa_unpack_sk(params, rho, tr, key, &t0, &s1, &s2, consistent_key.data());
87+
polyvecl mat[ML_DSA_K_MAX];
88+
ml_dsa_polyvec_matrix_expand(params, mat, rho);
89+
ml_dsa_polyvecl_ntt(params, &s1);
90+
polyveck t1;
91+
ml_dsa_polyvec_matrix_pointwise_montgomery(params, &t1, mat, &s1);
92+
ml_dsa_polyveck_reduce(params, &t1);
93+
ml_dsa_polyveck_invntt_tomont(params, &t1);
94+
ml_dsa_polyveck_add(params, &t1, &t1, &s2);
95+
ml_dsa_polyveck_caddq(params, &t1);
96+
ml_dsa_polyveck_power2round(params, &t1, &t0, &t1);
97+
std::vector<uint8_t> consistent_pk(params->public_key_bytes);
98+
ml_dsa_pack_pk(params, consistent_pk.data(), rho, &t1);
99+
100+
// Recompute tr = SHAKE256(pk, 64)
101+
std::vector<uint8_t> new_tr(ML_DSA_TRBYTES);
102+
bssl::ScopedEVP_MD_CTX md_ctx;
103+
if (!EVP_DigestInit_ex(md_ctx.get(), EVP_shake256(), nullptr) ||
104+
!EVP_DigestUpdate(md_ctx.get(), consistent_pk.data(),
105+
params->public_key_bytes) ||
106+
!EVP_DigestFinalXOF(md_ctx.get(), new_tr.data(), new_tr.size())) {
107+
return false;
108+
}
109+
110+
// Repack the consistent corrupted key
111+
ml_dsa_pack_sk(params, consistent_key.data(), rho, new_tr.data(),
112+
consistent_pk.data(), &t0, &s1, &s2);
113+
114+
// Verify the consistent key differs from the inconsistent one
115+
assert(std::memcmp(consistent_key.data(), corrupted_key.data(),
116+
params->secret_key_bytes) != 0);
117+
118+
// Output the test vector
119+
printf("# corrupted private key with invalid s1 or s2, consistent\n");
120+
printf("CorruptedKey = ");
121+
PrintHex(consistent_key);
122+
printf("\n\n");
123+
124+
return true;
125+
}
126+
127+
static bool InitializeParams(int nid, ml_dsa_params *params) {
128+
if (nid == NID_MLDSA44) {
129+
ml_dsa_44_params_init(params);
130+
} else if (nid == NID_MLDSA65) {
131+
ml_dsa_65_params_init(params);
132+
} else if (nid == NID_MLDSA87) {
133+
ml_dsa_87_params_init(params);
134+
} else {
135+
std::cerr << "Unexpected NID: " << nid << "\n";
136+
return false;
137+
}
138+
return true;
139+
}
140+
141+
static bool GenerateHonestKey(const MLDSAParamSet &ps,
142+
const ml_dsa_params &params,
143+
std::vector<uint8_t> *honest_key_bytes) {
144+
// Generate an honest private key from a fixed seed
145+
const std::vector<uint8_t> seed(32, 0x42);
146+
bssl::UniquePtr<EVP_PKEY> honest_pkey(
147+
EVP_PKEY_pqdsa_new_raw_private_key(ps.nid, seed.data(), seed.size()));
148+
if (!honest_pkey) {
149+
std::cerr << "Failed to generate honest key for " << ps.name << "\n";
150+
return false;
151+
}
152+
153+
// Export the honest private key to bytes
154+
size_t key_len = params.secret_key_bytes;
155+
honest_key_bytes->resize(key_len);
156+
if (!EVP_PKEY_get_raw_private_key(honest_pkey.get(), honest_key_bytes->data(),
157+
&key_len)) {
158+
std::cerr << "Failed to export honest key for " << ps.name << "\n";
159+
return false;
160+
}
161+
return true;
162+
}
163+
164+
static std::vector<CorruptionFn> CreateCorruptionFunctions(
165+
const ml_dsa_params &params, int vec_index, int coeff_index) {
166+
return {
167+
// Corrupt s1 with eta + 1
168+
[&params, vec_index, coeff_index](polyvecl *s1, polyveck *, polyveck *) {
169+
s1->vec[vec_index].coeffs[coeff_index] = params.eta + 1;
170+
},
171+
// Corrupt s1 with -(eta + 1)
172+
[&params, vec_index, coeff_index](polyvecl *s1, polyveck *, polyveck *) {
173+
s1->vec[vec_index].coeffs[coeff_index] = -(params.eta + 1);
174+
},
175+
// Corrupt s2 with eta + 1
176+
[&params, vec_index, coeff_index](polyvecl *, polyveck *s2, polyveck *) {
177+
s2->vec[vec_index].coeffs[coeff_index] = params.eta + 1;
178+
},
179+
// Corrupt s2 with -(eta + 1)
180+
[&params, vec_index, coeff_index](polyvecl *, polyveck *s2, polyveck *) {
181+
s2->vec[vec_index].coeffs[coeff_index] = -(params.eta + 1);
182+
},
183+
};
184+
}
185+
186+
int main() {
187+
printf(
188+
"# Invalid ML-DSA extended private keys\n"
189+
"# This file was generated by "
190+
"crypto/fipsmodule/ml_dsa/make_corrupted_key_tests.cc\n\n");
191+
192+
for (const auto &ps : kMLDSAs) {
193+
printf("[ParamSet = %s]\n", ps.name);
194+
195+
ml_dsa_params params;
196+
if (!InitializeParams(ps.nid, &params)) {
197+
return 1;
198+
}
199+
200+
std::vector<uint8_t> honest_key_bytes;
201+
if (!GenerateHonestKey(ps, params, &honest_key_bytes)) {
202+
return 1;
203+
}
204+
205+
// Test coefficient indices: first, last, and some random ones
206+
const int coeff_indices[] = {0, 255, 127, 95, 42, 224};
207+
208+
for (int vec_index = 0; vec_index < params.l; vec_index++) {
209+
for (int coeff_index : coeff_indices) {
210+
const std::vector<CorruptionFn> corruptions =
211+
CreateCorruptionFunctions(params, vec_index, coeff_index);
212+
213+
for (const auto &corruption : corruptions) {
214+
if (!GenerateCorruptedKey(ps, &params, honest_key_bytes,
215+
corruption)) {
216+
return 1;
217+
}
218+
}
219+
}
220+
}
221+
}
222+
return 0;
223+
}

crypto/fipsmodule/ml_dsa/ml_dsa_ref/packing.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ int ml_dsa_pack_pk_from_sk(ml_dsa_params *params,
3232
//unpack sk
3333
ml_dsa_unpack_sk(params, rho, tr, key, &t0, &s1, &s2, sk);
3434

35+
// check s1 and s2 have coefficients in [-ETA, ETA]
36+
if (ml_dsa_polyvecl_chknorm(params, &s1, params->eta + 1) ||
37+
ml_dsa_polyveck_chknorm(params, &s2, params->eta + 1)) {
38+
return 1;
39+
}
40+
3541
// generate matrix A
3642
ml_dsa_polyvec_matrix_expand(params, mat, rho);
3743

782 KB
Binary file not shown.

sources.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ set(
5252
crypto/cipher_extra/test/nist_cavp/tdes_ecb.txt
5353
crypto/ecdh_extra/ecdh_tests.txt
5454
crypto/evp_extra/kbkdf_expand_tests.txt
55+
crypto/evp_extra/mldsa_corrupted_key_tests.txt
5556
crypto/evp_extra/sshkdf_tests.txt
5657
crypto/evp_extra/evp_tests.txt
5758
crypto/evp_extra/scrypt_tests.txt

0 commit comments

Comments
 (0)