From a24bfe17cec1a35d0ef1a6da881b6be6593d4946 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sat, 4 Apr 2026 01:09:29 +1100 Subject: [PATCH] feat(mldsa): add mldsa --- mldsa/internal/byteorder/byteorder.go | 7 + mldsa/internal/constanttime/constant_time.go | 25 + mldsa/internal/cryptotest/stubs.go | 5 + mldsa/internal/fips140/mldsa/field.go | 782 ++++++++++++++++++ mldsa/internal/fips140/mldsa/field_test.go | 370 +++++++++ mldsa/internal/fips140/mldsa/mldsa.go | 783 +++++++++++++++++++ mldsa/internal/fips140/mldsa/mldsa_test.go | 345 ++++++++ mldsa/internal/fips140/mldsa/semiexpanded.go | 244 ++++++ mldsa/internal/fips140/mldsa/stubs.go | 10 + mldsa/internal/fips140/stubs.go | 5 + mldsa/internal/fips140test/mldsa_test.go | 729 +++++++++++++++++ mldsa/mldsa.go | 269 +++++++ mldsa/mldsa_test.go | 558 +++++++++++++ mldsa/mldsacrypto/mldsamu.go | 13 + 14 files changed, 4145 insertions(+) create mode 100644 mldsa/internal/byteorder/byteorder.go create mode 100644 mldsa/internal/constanttime/constant_time.go create mode 100644 mldsa/internal/cryptotest/stubs.go create mode 100644 mldsa/internal/fips140/mldsa/field.go create mode 100644 mldsa/internal/fips140/mldsa/field_test.go create mode 100644 mldsa/internal/fips140/mldsa/mldsa.go create mode 100644 mldsa/internal/fips140/mldsa/mldsa_test.go create mode 100644 mldsa/internal/fips140/mldsa/semiexpanded.go create mode 100644 mldsa/internal/fips140/mldsa/stubs.go create mode 100644 mldsa/internal/fips140/stubs.go create mode 100644 mldsa/internal/fips140test/mldsa_test.go create mode 100644 mldsa/mldsa.go create mode 100644 mldsa/mldsa_test.go create mode 100644 mldsa/mldsacrypto/mldsamu.go diff --git a/mldsa/internal/byteorder/byteorder.go b/mldsa/internal/byteorder/byteorder.go new file mode 100644 index 0000000..f458afa --- /dev/null +++ b/mldsa/internal/byteorder/byteorder.go @@ -0,0 +1,7 @@ +package byteorder + +import "encoding/binary" + +func LEPutUint16(b []byte, v uint16) { + binary.LittleEndian.PutUint16(b, v) +} diff --git a/mldsa/internal/constanttime/constant_time.go b/mldsa/internal/constanttime/constant_time.go new file mode 100644 index 0000000..24a7378 --- /dev/null +++ b/mldsa/internal/constanttime/constant_time.go @@ -0,0 +1,25 @@ +package constanttime + +import "crypto/subtle" + +// Select returns x if v == 1 and y if v == 0. +// Its behavior is undefined if v takes any other value. +func Select(v, x, y int) int { + return subtle.ConstantTimeSelect(v, x, y) +} + +// ByteEq returns 1 if x == y and 0 otherwise. +func ByteEq(x, y uint8) int { + return subtle.ConstantTimeByteEq(x, y) +} + +// Eq returns 1 if x == y and 0 otherwise. +func Eq(x, y int32) int { + return subtle.ConstantTimeEq(x, y) +} + +// LessOrEq returns 1 if x <= y and 0 otherwise. +// Its behavior is undefined if x or y are negative or > 2**31 - 1. +func LessOrEq(x, y int) int { + return subtle.ConstantTimeLessOrEq(x, y) +} diff --git a/mldsa/internal/cryptotest/stubs.go b/mldsa/internal/cryptotest/stubs.go new file mode 100644 index 0000000..ea03526 --- /dev/null +++ b/mldsa/internal/cryptotest/stubs.go @@ -0,0 +1,5 @@ +package cryptotest + +import "testing" + +func SkipTestAllocations(t *testing.T) {} diff --git a/mldsa/internal/fips140/mldsa/field.go b/mldsa/internal/fips140/mldsa/field.go new file mode 100644 index 0000000..f9dd9fa --- /dev/null +++ b/mldsa/internal/fips140/mldsa/field.go @@ -0,0 +1,782 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mldsa + +import ( + "crypto/sha3" + "errors" + "math/bits" + + "github.com/go-webauthn/x/mldsa/internal/constanttime" +) + +const ( + q = 8380417 // 2²³ - 2¹³ + 1 + R = 4294967296 // 2³² + RR = 2365951 // R² mod q, aka R in the Montgomery domain + qNegInv = 4236238847 // -q⁻¹ mod R (q * qNegInv ≡ -1 mod R) + one = 4193792 // R mod q, aka 1 in the Montgomery domain + minusOne = 4186625 // (q - 1) * R mod q, aka -1 in the Montgomery domain +) + +// fieldElement is an element n of ℤ_q in the Montgomery domain, represented as +// an integer x in [0, q) such that x ≡ n * R (mod q) where R = 2³². +type fieldElement uint32 + +var errUnreducedFieldElement = errors.New("mldsa: unreduced field element") + +// fieldToMontgomery checks that a value a is < q, and converts it to +// Montgomery form. +func fieldToMontgomery(a uint32) (fieldElement, error) { + if a >= q { + return 0, errUnreducedFieldElement + } + // a * R² * R⁻¹ ≡ a * R (mod q) + return fieldMontgomeryMul(fieldElement(a), RR), nil +} + +// fieldSubToMontgomery converts a difference a - b to Montgomery form. +// a and b must be < q. (This bound can probably be relaxed.) +func fieldSubToMontgomery(a, b uint32) fieldElement { + x := a - b + q + return fieldMontgomeryMul(fieldElement(x), RR) +} + +// fieldFromMontgomery converts a value a in Montgomery form back to +// standard representation. +func fieldFromMontgomery(a fieldElement) uint32 { + // (a * R) * 1 * R⁻¹ ≡ a (mod q) + return uint32(fieldMontgomeryReduce(uint64(a))) +} + +// fieldCenteredMod returns r mod± q, the value r reduced to the range +// [−(q−1)/2, (q−1)/2]. +func fieldCenteredMod(r fieldElement) int32 { + x := int32(fieldFromMontgomery(r)) + // x <= q / 2 ? x : x - q + return constantTimeSelectLessOrEqual(x, q/2, x, x-q) +} + +// fieldInfinityNorm returns the infinity norm ||r||∞ of r, or the absolute +// value of r centered around 0. +func fieldInfinityNorm(r fieldElement) uint32 { + x := int32(fieldFromMontgomery(r)) + // x <= q / 2 ? x : |x - q| + // |x - q| = -(x - q) = q - x because x < q => x - q < 0 + return uint32(constantTimeSelectLessOrEqual(x, q/2, x, q-x)) +} + +// fieldReduceOnce reduces a value a < 2q. +func fieldReduceOnce(a uint32) fieldElement { + x, b := bits.Sub64(uint64(a), uint64(q), 0) + return fieldElement(x + b*q) +} + +// fieldAdd returns a + b mod q. +func fieldAdd(a, b fieldElement) fieldElement { + x := uint32(a + b) + return fieldReduceOnce(x) +} + +// fieldSub returns a - b mod q. +func fieldSub(a, b fieldElement) fieldElement { + x := uint32(a - b + q) + return fieldReduceOnce(x) +} + +// fieldMontgomeryMul returns a * b * R⁻¹ mod q. +func fieldMontgomeryMul(a, b fieldElement) fieldElement { + x := uint64(a) * uint64(b) + return fieldMontgomeryReduce(x) +} + +// fieldMontgomeryReduce returns x * R⁻¹ mod q for x < q * R. +func fieldMontgomeryReduce(x uint64) fieldElement { + t := uint32(x) * qNegInv + u := (x + uint64(t)*q) >> 32 + return fieldReduceOnce(uint32(u)) +} + +// fieldMontgomeryMulSub returns a * (b - c). This operation is fused to save a +// fieldReduceOnce after the subtraction. +func fieldMontgomeryMulSub(a, b, c fieldElement) fieldElement { + x := uint64(a) * uint64(b-c+q) + return fieldMontgomeryReduce(x) +} + +// fieldMontgomeryAddMul returns a * b + c * d. This operation is fused to save +// a fieldReduceOnce and a fieldReduce. +func fieldMontgomeryAddMul(a, b, c, d fieldElement) fieldElement { + x := uint64(a) * uint64(b) + x += uint64(c) * uint64(d) + return fieldMontgomeryReduce(x) +} + +const n = 256 + +// ringElement is a polynomial, an element of R_q. +type ringElement [n]fieldElement + +// polyAdd adds two ringElements or nttElements. +func polyAdd[T ~[n]fieldElement](a, b T) (s T) { + for i := range s { + s[i] = fieldAdd(a[i], b[i]) + } + return s +} + +// polySub subtracts two ringElements or nttElements. +func polySub[T ~[n]fieldElement](a, b T) (s T) { + for i := range s { + s[i] = fieldSub(a[i], b[i]) + } + return s +} + +// nttElement is an NTT representation, an element of T_q. +type nttElement [n]fieldElement + +// zetas are the values ζ^BitRev₈(k) mod q for each index k, converted to the +// Montgomery domain. +var zetas = [256]fieldElement{4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468, 1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103, 2725464, 1024112, 7300517, 3585928, 7830929, 7260833, 2619752, 6271868, 6262231, 4520680, 6980856, 5102745, 1757237, 8360995, 4010497, 280005, 2706023, 95776, 3077325, 3530437, 6718724, 4788269, 5842901, 3915439, 4519302, 5336701, 3574422, 5512770, 3539968, 8079950, 2348700, 7841118, 6681150, 6736599, 3505694, 4558682, 3507263, 6239768, 6779997, 3699596, 811944, 531354, 954230, 3881043, 3900724, 5823537, 2071892, 5582638, 4450022, 6851714, 4702672, 5339162, 6927966, 3475950, 2176455, 6795196, 7122806, 1939314, 4296819, 7380215, 5190273, 5223087, 4747489, 126922, 3412210, 7396998, 2147896, 2715295, 5412772, 4686924, 7969390, 5903370, 7709315, 7151892, 8357436, 7072248, 7998430, 1349076, 1852771, 6949987, 5037034, 264944, 508951, 3097992, 44288, 7280319, 904516, 3958618, 4656075, 8371839, 1653064, 5130689, 2389356, 8169440, 759969, 7063561, 189548, 4827145, 3159746, 6529015, 5971092, 8202977, 1315589, 1341330, 1285669, 6795489, 7567685, 6940675, 5361315, 4499357, 4751448, 3839961, 2091667, 3407706, 2316500, 3817976, 5037939, 2244091, 5933984, 4817955, 266997, 2434439, 7144689, 3513181, 4860065, 4621053, 7183191, 5187039, 900702, 1859098, 909542, 819034, 495491, 6767243, 8337157, 7857917, 7725090, 5257975, 2031748, 3207046, 4823422, 7855319, 7611795, 4784579, 342297, 286988, 5942594, 4108315, 3437287, 5038140, 1735879, 203044, 2842341, 2691481, 5790267, 1265009, 4055324, 1247620, 2486353, 1595974, 4613401, 1250494, 2635921, 4832145, 5386378, 1869119, 1903435, 7329447, 7047359, 1237275, 5062207, 6950192, 7929317, 1312455, 3306115, 6417775, 7100756, 1917081, 5834105, 7005614, 1500165, 777191, 2235880, 3406031, 7838005, 5548557, 6709241, 6533464, 5796124, 4656147, 594136, 4603424, 6366809, 2432395, 2454455, 8215696, 1957272, 3369112, 185531, 7173032, 5196991, 162844, 1616392, 3014001, 810149, 1652634, 4686184, 6581310, 5341501, 3523897, 3866901, 269760, 2213111, 7404533, 1717735, 472078, 7953734, 1723600, 6577327, 1910376, 6712985, 7276084, 8119771, 4546524, 5441381, 6144432, 7959518, 6094090, 183443, 7403526, 1612842, 4834730, 7826001, 3919660, 8332111, 7018208, 3937738, 1400424, 7534263, 1976782} + +// ntt maps a ringElement to its nttElement representation. +// +// It implements NTT, according to FIPS 203, Algorithm 9. +func ntt(f ringElement) nttElement { + var m uint8 + + for len := 128; len >= 8; len /= 2 { + for start := 0; start < 256; start += 2 * len { + m++ + zeta := zetas[m] + + // Bounds check elimination hint. + f, flen := f[start:start+len], f[start+len:start+len+len] + for j := 0; j < len; j += 2 { + t := fieldMontgomeryMul(zeta, flen[j]) + flen[j] = fieldSub(f[j], t) + f[j] = fieldAdd(f[j], t) + + // Unroll by 2 for performance. + t = fieldMontgomeryMul(zeta, flen[j+1]) + flen[j+1] = fieldSub(f[j+1], t) + f[j+1] = fieldAdd(f[j+1], t) + } + } + } + + // Unroll len = 4, 2, and 1. + for start := 0; start < 256; start += 8 { + m++ + zeta := zetas[m] + + t := fieldMontgomeryMul(zeta, f[start+4]) + f[start+4] = fieldSub(f[start], t) + f[start] = fieldAdd(f[start], t) + + t = fieldMontgomeryMul(zeta, f[start+5]) + f[start+5] = fieldSub(f[start+1], t) + f[start+1] = fieldAdd(f[start+1], t) + + t = fieldMontgomeryMul(zeta, f[start+6]) + f[start+6] = fieldSub(f[start+2], t) + f[start+2] = fieldAdd(f[start+2], t) + + t = fieldMontgomeryMul(zeta, f[start+7]) + f[start+7] = fieldSub(f[start+3], t) + f[start+3] = fieldAdd(f[start+3], t) + } + for start := 0; start < 256; start += 4 { + m++ + zeta := zetas[m] + + t := fieldMontgomeryMul(zeta, f[start+2]) + f[start+2] = fieldSub(f[start], t) + f[start] = fieldAdd(f[start], t) + + t = fieldMontgomeryMul(zeta, f[start+3]) + f[start+3] = fieldSub(f[start+1], t) + f[start+1] = fieldAdd(f[start+1], t) + } + for start := 0; start < 256; start += 2 { + m++ + zeta := zetas[m] + + t := fieldMontgomeryMul(zeta, f[start+1]) + f[start+1] = fieldSub(f[start], t) + f[start] = fieldAdd(f[start], t) + } + + return nttElement(f) +} + +// inverseNTT maps a nttElement back to the ringElement it represents. +// +// It implements NTT⁻¹, according to FIPS 203, Algorithm 10. +func inverseNTT(f nttElement) ringElement { + var m uint8 = 255 + + // Unroll len = 1, 2, and 4. + for start := 0; start < 256; start += 2 { + zeta := zetas[m] + m-- + + t := f[start] + f[start] = fieldAdd(t, f[start+1]) + f[start+1] = fieldMontgomeryMulSub(zeta, f[start+1], t) + } + for start := 0; start < 256; start += 4 { + zeta := zetas[m] + m-- + + t := f[start] + f[start] = fieldAdd(t, f[start+2]) + f[start+2] = fieldMontgomeryMulSub(zeta, f[start+2], t) + + t = f[start+1] + f[start+1] = fieldAdd(t, f[start+3]) + f[start+3] = fieldMontgomeryMulSub(zeta, f[start+3], t) + } + for start := 0; start < 256; start += 8 { + zeta := zetas[m] + m-- + + t := f[start] + f[start] = fieldAdd(t, f[start+4]) + f[start+4] = fieldMontgomeryMulSub(zeta, f[start+4], t) + + t = f[start+1] + f[start+1] = fieldAdd(t, f[start+5]) + f[start+5] = fieldMontgomeryMulSub(zeta, f[start+5], t) + + t = f[start+2] + f[start+2] = fieldAdd(t, f[start+6]) + f[start+6] = fieldMontgomeryMulSub(zeta, f[start+6], t) + + t = f[start+3] + f[start+3] = fieldAdd(t, f[start+7]) + f[start+7] = fieldMontgomeryMulSub(zeta, f[start+7], t) + } + + for len := 8; len < 256; len *= 2 { + for start := 0; start < 256; start += 2 * len { + zeta := zetas[m] + m-- + + // Bounds check elimination hint. + f, flen := f[start:start+len], f[start+len:start+len+len] + for j := 0; j < len; j += 2 { + t := f[j] + f[j] = fieldAdd(t, flen[j]) + // -z * (t - flen[j]) = z * (flen[j] - t) + flen[j] = fieldMontgomeryMulSub(zeta, flen[j], t) + + // Unroll by 2 for performance. + t = f[j+1] + f[j+1] = fieldAdd(t, flen[j+1]) + flen[j+1] = fieldMontgomeryMulSub(zeta, flen[j+1], t) + } + } + } + + for i := range f { + f[i] = fieldMontgomeryMul(f[i], 16382) // 16382 = 256⁻¹ * R mod q + } + return ringElement(f) +} + +// nttMul multiplies two nttElements. +func nttMul(a, b nttElement) (p nttElement) { + for i := range p { + p[i] = fieldMontgomeryMul(a[i], b[i]) + } + return p +} + +// sampleNTT samples an nttElement uniformly at random from the seed rho and the +// indices s and r. It implements Step 3 of ExpandA, RejNTTPoly, and +// CoeffFromThreeBytes from FIPS 204, passing in ρ, s, and r instead of ρ'. +func sampleNTT(rho []byte, s, r byte) nttElement { + G := sha3.NewSHAKE128() + G.Write(rho) + G.Write([]byte{s, r}) + + var a nttElement + var j int // index into a + var buf [168]byte // buffered reads from B, matching the rate of SHAKE-128 + off := len(buf) // index into buf, starts in a "buffer fully consumed" state + for j < n { + if off >= len(buf) { + G.Read(buf[:]) + off = 0 + } + v := uint32(buf[off]) | uint32(buf[off+1])<<8 | uint32(buf[off+2])<<16 + off += 3 + f, err := fieldToMontgomery(v & 0b01111111_11111111_11111111) // 23 bits + if err != nil { + continue + } + a[j] = f + j++ + } + return a +} + +// sampleBoundedPoly samples a ringElement with coefficients in [−η, η] from the +// seed rho and the index r. It implements RejBoundedPoly and CoeffFromHalfByte +// from FIPS 204, passing in ρ and r separately from ExpandS. +func sampleBoundedPoly(rho []byte, r byte, p parameters) ringElement { + H := sha3.NewSHAKE256() + H.Write(rho) + H.Write([]byte{r, 0}) // IntegerToBytes(r, 2) + + var a ringElement + var j int + var buf [136]byte // buffered reads from H, matching the rate of SHAKE-256 + off := len(buf) // index into buf, starts in a "buffer fully consumed" state + for { + if off >= len(buf) { + H.Read(buf[:]) + off = 0 + } + z0 := buf[off] & 0x0F + z1 := buf[off] >> 4 + off++ + coeff, ok := coeffFromHalfByte(z0, p) + if ok { + a[j] = coeff + j++ + } + if j >= len(a) { + break + } + coeff, ok = coeffFromHalfByte(z1, p) + if ok { + a[j] = coeff + j++ + } + if j >= len(a) { + break + } + } + return a +} + +// sampleInBall samples a ringElement with coefficients in {−1, 0, 1}, and τ +// non-zero coefficients. It is not constant-time. +func sampleInBall(rho []byte, p parameters) ringElement { + H := sha3.NewSHAKE256() + H.Write(rho) + s := make([]byte, 8) + H.Read(s) + + var c ringElement + for i := 256 - p.τ; i < 256; i++ { + j := make([]byte, 1) + H.Read(j) + for j[0] > byte(i) { + H.Read(j) + } + c[i] = c[j[0]] + // c[j] = (−1) ^ h[i+τ−256], where h are the bits in s in little-endian. + // That is, -1⁰ = 1 if the bit is 0, -1¹ = -1 if it is 1. + bitIdx := i + p.τ - 256 + bit := (s[bitIdx/8] >> (bitIdx % 8)) & 1 + if bit == 0 { + c[j[0]] = one + } else { + c[j[0]] = minusOne + } + } + + return c +} + +// coeffFromHalfByte implements CoeffFromHalfByte from FIPS 204. +// +// It maps a value in [0, 15] to a coefficient in [−η, η] +func coeffFromHalfByte(b byte, p parameters) (fieldElement, bool) { + if b > 15 { + panic("internal error: half-byte out of range") + } + switch p.η { + case 2: + // Return z = 2 − (b mod 5), which maps from + // + // b = ( 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 ) + // + // to + // + // b%5 = ( 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0 ) + // + // to + // + // z = ( -2, -1, 0, 1, 2, -2, -1, 0, 1, 2, -2, -1, 0, 1, 2 ) + // + if b > 14 { + return 0, false + } + // Calculate b % 5 with Barrett reduction, to avoid a potentially + // variable-time division. + const barrettMultiplier = 0x3334 // ⌈2¹⁶ / 5⌉ + const barrettShift = 16 // log₂(2¹⁶) + quotient := (uint32(b) * barrettMultiplier) >> barrettShift + remainder := uint32(b) - quotient*5 + return fieldSubToMontgomery(2, remainder), true + case 4: + // Return z = 4 − b, which maps from + // + // b = ( 8, 7, 6, 5, 4, 3, 2, 1, 0 ) + // + // to + // + // z = ( −4, -3, -2, -1, 0, 1, 2, 3, 4 ) + // + if b > 8 { + return 0, false + } + return fieldSubToMontgomery(4, uint32(b)), true + default: + panic("internal error: unsupported η") + } +} + +// power2Round implements Power2Round from FIPS 204. +// +// It separates the bottom d = 13 bits of each 23-bit coefficient, rounding the +// high part based on the low part, and correcting the low part accordingly. +func power2Round(r fieldElement) (hi uint16, lo fieldElement) { + rr := fieldFromMontgomery(r) + // Add 2¹² - 1 to round up r1 by one if r0 > 2¹². + // r is at most 2²³ - 2¹³ + 1, so rr + (2¹² - 1) won't overflow 23 bits. + r1 := rr + 1<<12 - 1 + r1 >>= 13 + // r1 <= 2¹⁰ - 1 + // r1 * 2¹³ <= (2¹⁰ - 1) * 2¹³ = 2²³ - 2¹³ < q + r0 := fieldSubToMontgomery(rr, r1<<13) + return uint16(r1), r0 +} + +// highBits implements HighBits from FIPS 204. +func highBits(r ringElement, p parameters) [n]byte { + var w [n]byte + switch p.γ2 { + case 32: + for i := range n { + w[i] = highBits32(fieldFromMontgomery(r[i])) + } + case 88: + for i := range n { + w[i] = highBits88(fieldFromMontgomery(r[i])) + } + default: + panic("mldsa: internal error: unsupported γ2") + } + return w +} + +// useHint implements UseHint from FIPS 204. +// +// It is not constant-time. +func useHint(r ringElement, h [n]byte, p parameters) [n]byte { + var w [n]byte + switch p.γ2 { + case 32: + for i := range n { + w[i] = useHint32(r[i], h[i]) + } + case 88: + for i := range n { + w[i] = useHint88(r[i], h[i]) + } + default: + panic("mldsa: internal error: unsupported γ2") + } + return w +} + +// makeHint implements MakeHint from FIPS 204. +func makeHint(ct0, w, cs2 ringElement, p parameters) (h [n]byte, count1s int) { + switch p.γ2 { + case 32: + for i := range n { + h[i] = makeHint32(ct0[i], w[i], cs2[i]) + count1s += int(h[i]) + } + case 88: + for i := range n { + h[i] = makeHint88(ct0[i], w[i], cs2[i]) + count1s += int(h[i]) + } + default: + panic("mldsa: internal error: unsupported γ2") + } + return h, count1s +} + +// highBits32 implements HighBits from FIPS 204 for γ2 = (q - 1) / 32. +func highBits32(x uint32) byte { + // The implementation is based on the reference implementation and on + // BoringSSL. There are exhaustive tests in TestDecompose that compare it to + // a straightforward implementation of Decompose from the spec, so for our + // purposes it only has to work and be constant-time. + r1 := (x + 127) >> 7 + r1 = (r1*1025 + (1 << 21)) >> 22 + r1 &= 0b1111 + return byte(r1) +} + +// decompose32 implements Decompose from FIPS 204 for γ2 = (q - 1) / 32. +// +// r1 is in [0, 15]. +func decompose32(r fieldElement) (r1 byte, r0 int32) { + x := fieldFromMontgomery(r) + r1 = highBits32(x) + + // r - r1 * (2 * γ2) mod± q + r0 = int32(x) - int32(r1)*2*(q-1)/32 + r0 = constantTimeSelectLessOrEqual(q/2+1, r0, r0-q, r0) + + return r1, r0 +} + +// useHint32 implements UseHint from FIPS 204 for γ2 = (q - 1) / 32. +func useHint32(r fieldElement, hint byte) byte { + const m = 16 // (q − 1) / (2 * γ2) + r1, r0 := decompose32(r) + if hint == 1 { + if r0 > 0 { + r1 = (r1 + 1) % m + } else { + // Underflow is safe, because it operates modulo 256 (since the type + // is byte), which is a multiple of m. + r1 = (r1 - 1) % m + } + } + return r1 +} + +// makeHint32 implements MakeHint from FIPS 204 for γ2 = (q - 1) / 32. +func makeHint32(ct0, w, cs2 fieldElement) byte { + // v1 = HighBits(r + z) = HighBits(w - cs2 + ct0 - ct0) = HighBits(w - cs2) + rPlusZ := fieldSub(w, cs2) + v1 := highBits32(fieldFromMontgomery(rPlusZ)) + // r1 = HighBits(r) = HighBits(w - cs2 + ct0) + r1 := highBits32(fieldFromMontgomery(fieldAdd(rPlusZ, ct0))) + + return byte(constanttime.ByteEq(v1, r1) ^ 1) +} + +// highBits88 implements HighBits from FIPS 204 for γ2 = (q - 1) / 88. +func highBits88(x uint32) byte { + // Like highBits32, this is exhaustively tested in TestDecompose. + r1 := (x + 127) >> 7 + r1 = (r1*11275 + (1 << 23)) >> 24 + r1 = constantTimeSelectEqual(r1, 44, 0, r1) + return byte(r1) +} + +// decompose88 implements Decompose from FIPS 204 for γ2 = (q - 1) / 88. +// +// r1 is in [0, 43]. +func decompose88(r fieldElement) (r1 byte, r0 int32) { + x := fieldFromMontgomery(r) + r1 = highBits88(x) + + // r - r1 * (2 * γ2) mod± q + r0 = int32(x) - int32(r1)*2*(q-1)/88 + r0 = constantTimeSelectLessOrEqual(q/2+1, r0, r0-q, r0) + + return r1, r0 +} + +// useHint88 implements UseHint from FIPS 204 for γ2 = (q - 1) / 88. +func useHint88(r fieldElement, hint byte) byte { + const m = 44 // (q − 1) / (2 * γ2) + r1, r0 := decompose88(r) + if hint == 1 { + if r0 > 0 { + // (r1 + 1) mod m, for r1 in [0, m-1] + if r1 == m-1 { + r1 = 0 + } else { + r1++ + } + } else { + // (r1 - 1) % m, for r1 in [0, m-1] + if r1 == 0 { + r1 = m - 1 + } else { + r1-- + } + } + } + return r1 +} + +// makeHint88 implements MakeHint from FIPS 204 for γ2 = (q - 1) / 88. +func makeHint88(ct0, w, cs2 fieldElement) byte { + // Same as makeHint32 above. + rPlusZ := fieldSub(w, cs2) + v1 := highBits88(fieldFromMontgomery(rPlusZ)) + r1 := highBits88(fieldFromMontgomery(fieldAdd(rPlusZ, ct0))) + return byte(constanttime.ByteEq(v1, r1) ^ 1) +} + +// bitPack implements BitPack(r mod± q, γ₁-1, γ₁), which packs the centered +// coefficients of r into little-endian γ1+1-bit chunks. It appends to buf. +// +// It must only be applied to r with coefficients in [−γ₁+1, γ₁], as +// guaranteed by the rejection conditions in Sign. +func bitPack(b []byte, r ringElement, p parameters) []byte { + switch p.γ1 { + case 17: + return bitPack18(b, r) + case 19: + return bitPack20(b, r) + default: + panic("mldsa: internal error: unsupported γ1") + } +} + +// bitPack18 implements BitPack(r mod± q, 2¹⁷-1, 2¹⁷), which packs the centered +// coefficients of r into little-endian 18-bit chunks. It appends to buf. +// +// It must only be applied to r with coefficients in [−2¹⁷+1, 2¹⁷], as +// guaranteed by the rejection conditions in Sign. +func bitPack18(buf []byte, r ringElement) []byte { + out, v := sliceForAppend(buf, 18*n/8) + const b = 1 << 17 + for i := 0; i < n; i += 4 { + // b - [−2¹⁷+1, 2¹⁷] = [0, 2²⁸-1] + w0 := b - fieldCenteredMod(r[i]) + v[0] = byte(w0 << 0) + v[1] = byte(w0 >> 8) + v[2] = byte(w0 >> 16) + w1 := b - fieldCenteredMod(r[i+1]) + v[2] |= byte(w1 << 2) + v[3] = byte(w1 >> 6) + v[4] = byte(w1 >> 14) + w2 := b - fieldCenteredMod(r[i+2]) + v[4] |= byte(w2 << 4) + v[5] = byte(w2 >> 4) + v[6] = byte(w2 >> 12) + w3 := b - fieldCenteredMod(r[i+3]) + v[6] |= byte(w3 << 6) + v[7] = byte(w3 >> 2) + v[8] = byte(w3 >> 10) + v = v[4*18/8:] + } + return out +} + +// bitPack20 implements BitPack(r mod± q, 2¹⁹-1, 2¹⁹), which packs the centered +// coefficients of r into little-endian 20-bit chunks. It appends to buf. +// +// It must only be applied to r with coefficients in [−2¹⁹+1, 2¹⁹], as +// guaranteed by the rejection conditions in Sign. +func bitPack20(buf []byte, r ringElement) []byte { + out, v := sliceForAppend(buf, 20*n/8) + const b = 1 << 19 + for i := 0; i < n; i += 2 { + // b - [−2¹⁹+1, 2¹⁹] = [0, 2²⁰-1] + w0 := b - fieldCenteredMod(r[i]) + v[0] = byte(w0 << 0) + v[1] = byte(w0 >> 8) + v[2] = byte(w0 >> 16) + w1 := b - fieldCenteredMod(r[i+1]) + v[2] |= byte(w1 << 4) + v[3] = byte(w1 >> 4) + v[4] = byte(w1 >> 12) + v = v[2*20/8:] + } + return out +} + +// bitUnpack implements BitUnpack(v, 2^γ1-1, 2^γ1), which unpacks each γ1+1 bits +// in little-endian into a coefficient in [-2^γ1+1, 2^γ1]. +func bitUnpack(v []byte, p parameters) ringElement { + switch p.γ1 { + case 17: + return bitUnpack18(v) + case 19: + return bitUnpack20(v) + default: + panic("mldsa: internal error: unsupported γ1") + } +} + +// bitUnpack18 implements BitUnpack(v, 2¹⁷-1, 2¹⁷), which unpacks each 18 bits +// in little-endian into a coefficient in [-2¹⁷+1, 2¹⁷]. +func bitUnpack18(v []byte) ringElement { + if len(v) != 18*n/8 { + panic("mldsa: internal error: invalid bitUnpack18 input length") + } + const b = 1 << 17 + const mask18 = 1<<18 - 1 + var r ringElement + for i := 0; i < n; i += 4 { + w0 := uint32(v[0]) | uint32(v[1])<<8 | uint32(v[2])<<16 + r[i+0] = fieldSubToMontgomery(b, w0&mask18) + w1 := uint32(v[2])>>2 | uint32(v[3])<<6 | uint32(v[4])<<14 + r[i+1] = fieldSubToMontgomery(b, w1&mask18) + w2 := uint32(v[4])>>4 | uint32(v[5])<<4 | uint32(v[6])<<12 + r[i+2] = fieldSubToMontgomery(b, w2&mask18) + w3 := uint32(v[6])>>6 | uint32(v[7])<<2 | uint32(v[8])<<10 + r[i+3] = fieldSubToMontgomery(b, w3&mask18) + v = v[4*18/8:] + } + return r +} + +// bitUnpack20 implements BitUnpack(v, 2¹⁹-1, 2¹⁹), which unpacks each 20 bits +// in little-endian into a coefficient in [-2¹⁹+1, 2¹⁹]. +func bitUnpack20(v []byte) ringElement { + if len(v) != 20*n/8 { + panic("mldsa: internal error: invalid bitUnpack20 input length") + } + const b = 1 << 19 + const mask20 = 1<<20 - 1 + var r ringElement + for i := 0; i < n; i += 2 { + w0 := uint32(v[0]) | uint32(v[1])<<8 | uint32(v[2])<<16 + r[i+0] = fieldSubToMontgomery(b, w0&mask20) + w1 := uint32(v[2])>>4 | uint32(v[3])<<4 | uint32(v[4])<<12 + r[i+1] = fieldSubToMontgomery(b, w1&mask20) + v = v[2*20/8:] + } + return r +} + +// sliceForAppend takes a slice and a requested number of bytes. It returns a +// slice with the contents of the given slice followed by that many bytes and a +// second slice that aliases into it and contains only the extra bytes. If the +// original slice has sufficient capacity then no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} + +// constantTimeSelectLessOrEqual returns yes if a <= b, no otherwise, in constant time. +func constantTimeSelectLessOrEqual(a, b, yes, no int32) int32 { + return int32(constanttime.Select(constanttime.LessOrEq(int(a), int(b)), int(yes), int(no))) +} + +// constantTimeSelectEqual returns yes if a == b, no otherwise, in constant time. +func constantTimeSelectEqual(a, b, yes, no uint32) uint32 { + return uint32(constanttime.Select(constanttime.Eq(int32(a), int32(b)), int(yes), int(no))) +} + +// constantTimeAbs returns the absolute value of x in constant time. +func constantTimeAbs(x int32) uint32 { + return uint32(constantTimeSelectLessOrEqual(0, x, x, -x)) +} diff --git a/mldsa/internal/fips140/mldsa/field_test.go b/mldsa/internal/fips140/mldsa/field_test.go new file mode 100644 index 0000000..41680f4 --- /dev/null +++ b/mldsa/internal/fips140/mldsa/field_test.go @@ -0,0 +1,370 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mldsa + +import ( + "math/big" + "testing" +) + +type interestingValue struct { + v uint32 + m fieldElement +} + +// q is large enough that we can't exhaustively test all q × q inputs, so when +// we have two inputs we test [0, q) on one side and a set of interesting +// values on the other side. +func interestingValues() []interestingValue { + if testing.Short() { + return []interestingValue{{v: q - 1, m: minusOne}} + } + var values []interestingValue + for _, v := range []uint32{ + 0, + 1, + 2, + 3, + q - 3, + q - 2, + q - 1, + q / 2, + (q + 1) / 2, + } { + m, _ := fieldToMontgomery(v) + values = append(values, interestingValue{v: v, m: m}) + // Also test values that have an interesting Montgomery representation. + values = append(values, interestingValue{ + v: fieldFromMontgomery(fieldElement(v)), m: fieldElement(v)}) + } + return values +} + +func TestToFromMontgomery(t *testing.T) { + for a := range uint32(q) { + m, err := fieldToMontgomery(a) + if err != nil { + t.Fatalf("fieldToMontgomery(%d) returned error: %v", a, err) + } + exp := fieldElement((uint64(a) * R) % q) + if m != exp { + t.Fatalf("fieldToMontgomery(%d) = %d, expected %d", a, m, exp) + } + got := fieldFromMontgomery(m) + if got != a { + t.Fatalf("fieldFromMontgomery(fieldToMontgomery(%d)) = %d, expected %d", a, got, a) + } + } +} + +func TestFieldAdd(t *testing.T) { + t.Parallel() + for _, a := range interestingValues() { + for b := range fieldElement(q) { + got := fieldAdd(a.m, b) + exp := (a.m + b) % q + if got != exp { + t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp) + } + } + } +} + +func TestFieldSub(t *testing.T) { + t.Parallel() + for _, a := range interestingValues() { + for b := range fieldElement(q) { + got := fieldSub(a.m, b) + exp := (a.m + q - b) % q + if got != exp { + t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp) + } + } + } +} + +func TestFieldSubToMontgomery(t *testing.T) { + t.Parallel() + for _, a := range interestingValues() { + for b := range uint32(q) { + got := fieldSubToMontgomery(a.v, b) + diff := (a.v + q - b) % q + exp := fieldElement((uint64(diff) * R) % q) + if got != exp { + t.Fatalf("fieldSubToMontgomery(%d, %d) = %d, expected %d", a.v, b, got, exp) + } + } + } +} + +func TestFieldReduceOnce(t *testing.T) { + t.Parallel() + for a := range uint32(2 * q) { + got := fieldReduceOnce(a) + var exp uint32 + if a < q { + exp = a + } else { + exp = a - q + } + if uint32(got) != exp { + t.Fatalf("fieldReduceOnce(%d) = %d, expected %d", a, got, exp) + } + } +} + +func TestFieldMul(t *testing.T) { + t.Parallel() + for _, a := range interestingValues() { + for b := range fieldElement(q) { + got := fieldFromMontgomery(fieldMontgomeryMul(a.m, b)) + exp := uint32((uint64(a.v) * uint64(fieldFromMontgomery(b))) % q) + if got != exp { + t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp) + } + } + } +} + +func TestFieldToMontgomeryOverflow(t *testing.T) { + // fieldToMontgomery should reject inputs ≥ q. + inputs := []uint32{ + q, + q + 1, + q + 2, + 1<<23 - 1, + 1 << 23, + q + 1<<23, + q + 1<<31, + ^uint32(0), + } + for _, in := range inputs { + if _, err := fieldToMontgomery(in); err == nil { + t.Fatalf("fieldToMontgomery(%d) did not return an error", in) + } + } +} + +func TestFieldMulSub(t *testing.T) { + for _, a := range interestingValues() { + for _, b := range interestingValues() { + for _, c := range interestingValues() { + got := fieldFromMontgomery(fieldMontgomeryMulSub(a.m, b.m, c.m)) + exp := uint32((uint64(a.v) * (uint64(b.v) + q - uint64(c.v))) % q) + if got != exp { + t.Fatalf("%d * (%d - %d) = %d, expected %d", a.v, b.v, c.v, got, exp) + } + } + } + } +} + +func TestFieldAddMul(t *testing.T) { + for _, a := range interestingValues() { + for _, b := range interestingValues() { + for _, c := range interestingValues() { + for _, d := range interestingValues() { + got := fieldFromMontgomery(fieldMontgomeryAddMul(a.m, b.m, c.m, d.m)) + exp := uint32((uint64(a.v)*uint64(b.v) + uint64(c.v)*uint64(d.v)) % q) + if got != exp { + t.Fatalf("%d + %d * %d = %d, expected %d", a.v, b.v, c.v, got, exp) + } + } + } + } + } +} + +func BitRev8(n uint8) uint8 { + var r uint8 + r |= n >> 7 & 0b0000_0001 + r |= n >> 5 & 0b0000_0010 + r |= n >> 3 & 0b0000_0100 + r |= n >> 1 & 0b0000_1000 + r |= n << 1 & 0b0001_0000 + r |= n << 3 & 0b0010_0000 + r |= n << 5 & 0b0100_0000 + r |= n << 7 & 0b1000_0000 + return r +} + +func CenteredMod(x, m uint32) int32 { + x = x % m + if x > m/2 { + return int32(x) - int32(m) + } + return int32(x) +} + +func reduceModQ(x int32) uint32 { + x %= q + if x < 0 { + return uint32(x + q) + } + return uint32(x) +} + +func TestCenteredMod(t *testing.T) { + for x := range uint32(q * 2) { + got := CenteredMod(uint32(x), q) + if reduceModQ(got) != (x % q) { + t.Fatalf("CenteredMod(%d) = %d, which is not congruent to %d mod %d", x, got, x, q) + } + } + + for x := range uint32(q) { + r, _ := fieldToMontgomery(x) + got := fieldCenteredMod(r) + exp := CenteredMod(x, q) + if got != exp { + t.Fatalf("fieldCenteredMod(%d) = %d, expected %d", x, got, exp) + } + } +} + +func TestInfinityNorm(t *testing.T) { + for x := range uint32(q) { + r, _ := fieldToMontgomery(x) + got := fieldInfinityNorm(r) + exp := CenteredMod(x, q) + if exp < 0 { + exp = -exp + } + if got != uint32(exp) { + t.Fatalf("fieldInfinityNorm(%d) = %d, expected %d", x, got, exp) + } + } +} + +func TestConstants(t *testing.T) { + if fieldFromMontgomery(one) != 1 { + t.Errorf("one constant incorrect") + } + if fieldFromMontgomery(minusOne) != q-1 { + t.Errorf("minusOne constant incorrect") + } + if fieldInfinityNorm(one) != 1 { + t.Errorf("one infinity norm incorrect") + } + if fieldInfinityNorm(minusOne) != 1 { + t.Errorf("minusOne infinity norm incorrect") + } + + if PublicKeySize44 != pubKeySize(params44) { + t.Errorf("PublicKeySize44 constant incorrect") + } + if PublicKeySize65 != pubKeySize(params65) { + t.Errorf("PublicKeySize65 constant incorrect") + } + if PublicKeySize87 != pubKeySize(params87) { + t.Errorf("PublicKeySize87 constant incorrect") + } + if SignatureSize44 != sigSize(params44) { + t.Errorf("SignatureSize44 constant incorrect") + } + if SignatureSize65 != sigSize(params65) { + t.Errorf("SignatureSize65 constant incorrect") + } + if SignatureSize87 != sigSize(params87) { + t.Errorf("SignatureSize87 constant incorrect") + } +} + +func TestPower2Round(t *testing.T) { + t.Parallel() + for x := range uint32(q) { + rr, _ := fieldToMontgomery(x) + t1, t0 := power2Round(rr) + + hi, err := fieldToMontgomery(uint32(t1) << 13) + if err != nil { + t.Fatalf("power2Round(%d): failed to convert high part to Montgomery: %v", x, err) + } + if r := fieldFromMontgomery(fieldAdd(hi, t0)); r != x { + t.Fatalf("power2Round(%d) = (%d, %d), which reconstructs to %d, expected %d", x, t1, t0, r, x) + } + } +} + +func SpecDecompose(rr fieldElement, p parameters) (R1 uint32, R0 int32) { + r := fieldFromMontgomery(rr) + if (q-1)%p.γ2 != 0 { + panic("mldsa: internal error: unsupported denγ2") + } + γ2 := (q - 1) / uint32(p.γ2) + r0 := CenteredMod(r, 2*γ2) + diff := int32(r) - r0 + if diff == q-1 { + r0 = r0 - 1 + return 0, r0 + } else { + if diff < 0 || uint32(diff)%γ2 != 0 { + panic("mldsa: internal error: invalid decomposition") + } + r1 := uint32(diff) / (2 * γ2) + return r1, r0 + } +} + +func TestDecompose(t *testing.T) { + t.Run("ML-DSA-44", func(t *testing.T) { + testDecompose(t, params44) + }) + t.Run("ML-DSA-65,87", func(t *testing.T) { + testDecompose(t, params65) + }) +} + +func testDecompose(t *testing.T, p parameters) { + t.Parallel() + for x := range uint32(q) { + rr, _ := fieldToMontgomery(x) + r1, r0 := SpecDecompose(rr, p) + + // Check that SpecDecompose is correct. + // r ≡ r1 * (2 * γ2) + r0 mod q + γ2 := (q - 1) / uint32(p.γ2) + reconstructed := reduceModQ(int32(r1*2*γ2) + r0) + if reconstructed != x { + t.Fatalf("SpecDecompose(%d) = (%d, %d), which reconstructs to %d, expected %d", x, r1, r0, reconstructed, x) + } + + var gotR1 byte + var gotR0 int32 + switch p.γ2 { + case 88: + gotR1, gotR0 = decompose88(rr) + if gotR1 > 43 { + t.Fatalf("decompose88(%d) returned r1 = %d, which is out of range", x, gotR1) + } + case 32: + gotR1, gotR0 = decompose32(rr) + if gotR1 > 15 { + t.Fatalf("decompose32(%d) returned r1 = %d, which is out of range", x, gotR1) + } + default: + t.Fatalf("unsupported denγ2: %d", p.γ2) + } + if uint32(gotR1) != r1 { + t.Fatalf("highBits(%d) = %d, expected %d", x, gotR1, r1) + } + if gotR0 != r0 { + t.Fatalf("lowBits(%d) = %d, expected %d", x, gotR0, r0) + } + } +} + +func TestZetas(t *testing.T) { + ζ := big.NewInt(1753) + q := big.NewInt(q) + for k, zeta := range zetas { + // ζ^BitRev₈(k) mod q + exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev8(uint8(k)))), q) + got := fieldFromMontgomery(zeta) + if big.NewInt(int64(got)).Cmp(exp) != 0 { + t.Errorf("zetas[%d] = %v, expected %v", k, got, exp) + } + } +} diff --git a/mldsa/internal/fips140/mldsa/mldsa.go b/mldsa/internal/fips140/mldsa/mldsa.go new file mode 100644 index 0000000..9ea53a7 --- /dev/null +++ b/mldsa/internal/fips140/mldsa/mldsa.go @@ -0,0 +1,783 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mldsa + +import ( + "bytes" + drbg "crypto/rand" + "crypto/sha3" + "crypto/subtle" + "errors" + + "github.com/go-webauthn/x/mldsa/internal/byteorder" + "github.com/go-webauthn/x/mldsa/internal/fips140" +) + +type parameters struct { + k, l int // dimensions of A + η int // bound for secret coefficients + γ1 int // log₂(γ₁), where [-γ₁+1, γ₁] is the bound of y + γ2 int // denominator of γ₂ = (q - 1) / γ2 + λ int // collison strength + τ int // number of non-zero coefficients in challenge + ω int // max number of hints in MakeHint +} + +var ( + params44 = parameters{k: 4, l: 4, η: 2, γ1: 17, γ2: 88, λ: 128, τ: 39, ω: 80} + params65 = parameters{k: 6, l: 5, η: 4, γ1: 19, γ2: 32, λ: 192, τ: 49, ω: 55} + params87 = parameters{k: 8, l: 7, η: 2, γ1: 19, γ2: 32, λ: 256, τ: 60, ω: 75} +) + +func pubKeySize(p parameters) int { + // ρ + k × n × 10-bit coefficients of t₁ + return 32 + p.k*n*10/8 +} + +func sigSize(p parameters) int { + // challenge + l × n × (γ₁+1)-bit coefficients of z + hint + return (p.λ / 4) + p.l*n*(p.γ1+1)/8 + p.ω + p.k +} + +const ( + PrivateKeySize = 32 + + PublicKeySize44 = 32 + 4*n*10/8 + PublicKeySize65 = 32 + 6*n*10/8 + PublicKeySize87 = 32 + 8*n*10/8 + + SignatureSize44 = 128/4 + 4*n*(17+1)/8 + 80 + 4 + SignatureSize65 = 192/4 + 5*n*(19+1)/8 + 55 + 6 + SignatureSize87 = 256/4 + 7*n*(19+1)/8 + 75 + 8 +) + +const maxK, maxL, maxλ, maxγ1 = 8, 7, 256, 19 +const maxPubKeySize = PublicKeySize87 + +type PrivateKey struct { + seed [32]byte + pub PublicKey + s1 [maxL]nttElement + s2 [maxK]nttElement + t0 [maxK]nttElement + k [32]byte +} + +func (priv *PrivateKey) Equal(x *PrivateKey) bool { + return priv.pub.p == x.pub.p && subtle.ConstantTimeCompare(priv.seed[:], x.seed[:]) == 1 +} + +func (priv *PrivateKey) Bytes() []byte { + seed := priv.seed + return seed[:] +} + +func (priv *PrivateKey) PublicKey() *PublicKey { + // Note that this is likely to keep the entire PrivateKey reachable for + // the lifetime of the PublicKey, which may be undesirable. + return &priv.pub +} + +type PublicKey struct { + raw [maxPubKeySize]byte + p parameters + a [maxK * maxL]nttElement + t1 [maxK]nttElement // NTT(t₁ ⋅ 2ᵈ) + tr [64]byte // public key hash +} + +func (pub *PublicKey) Equal(x *PublicKey) bool { + size := pubKeySize(pub.p) + return pub.p == x.p && subtle.ConstantTimeCompare(pub.raw[:size], x.raw[:size]) == 1 +} + +func (pub *PublicKey) Bytes() []byte { + size := pubKeySize(pub.p) + return bytes.Clone(pub.raw[:size]) +} + +func (pub *PublicKey) Parameters() string { + switch pub.p { + case params44: + return "ML-DSA-44" + case params65: + return "ML-DSA-65" + case params87: + return "ML-DSA-87" + default: + panic("mldsa: internal error: unknown parameters") + } +} + +func GenerateKey44() *PrivateKey { + fipsSelfTest() + fips140.RecordApproved() + var seed [32]byte + drbg.Read(seed[:]) + priv := newPrivateKey(&seed, params44) + fipsPCT(priv) + return priv +} + +func GenerateKey65() *PrivateKey { + fipsSelfTest() + fips140.RecordApproved() + var seed [32]byte + drbg.Read(seed[:]) + priv := newPrivateKey(&seed, params65) + fipsPCT(priv) + return priv +} + +func GenerateKey87() *PrivateKey { + fipsSelfTest() + fips140.RecordApproved() + var seed [32]byte + drbg.Read(seed[:]) + priv := newPrivateKey(&seed, params87) + fipsPCT(priv) + return priv +} + +var errInvalidSeedLength = errors.New("mldsa: invalid seed length") + +func NewPrivateKey44(seed []byte) (*PrivateKey, error) { + fipsSelfTest() + fips140.RecordApproved() + if len(seed) != 32 { + return nil, errInvalidSeedLength + } + return newPrivateKey((*[32]byte)(seed), params44), nil +} + +func NewPrivateKey65(seed []byte) (*PrivateKey, error) { + fipsSelfTest() + fips140.RecordApproved() + if len(seed) != 32 { + return nil, errInvalidSeedLength + } + return newPrivateKey((*[32]byte)(seed), params65), nil +} + +func NewPrivateKey87(seed []byte) (*PrivateKey, error) { + fipsSelfTest() + fips140.RecordApproved() + if len(seed) != 32 { + return nil, errInvalidSeedLength + } + return newPrivateKey((*[32]byte)(seed), params87), nil +} + +func newPrivateKey(seed *[32]byte, p parameters) *PrivateKey { + k, l := p.k, p.l + + priv := &PrivateKey{pub: PublicKey{p: p}} + priv.seed = *seed + + ξ := sha3.NewSHAKE256() + ξ.Write(seed[:]) + ξ.Write([]byte{byte(k), byte(l)}) + ρ, ρs := make([]byte, 32), make([]byte, 64) + ξ.Read(ρ) + ξ.Read(ρs) + ξ.Read(priv.k[:]) + + A := priv.pub.a[:k*l] + computeMatrixA(A, ρ, p) + + s1 := priv.s1[:l] + for r := range l { + s1[r] = ntt(sampleBoundedPoly(ρs, byte(r), p)) + } + s2 := priv.s2[:k] + for r := range k { + s2[r] = ntt(sampleBoundedPoly(ρs, byte(l+r), p)) + } + + // ˆt = Â ∘ ŝ₁ + ŝ₂ + tHat := make([]nttElement, k, maxK) + for i := range tHat { + tHat[i] = s2[i] + for j := range s1 { + tHat[i] = polyAdd(tHat[i], nttMul(A[i*l+j], s1[j])) + } + } + // t = NTT⁻¹(ˆt) + t := make([]ringElement, k, maxK) + for i := range tHat { + t[i] = inverseNTT(tHat[i]) + } + // (t₁, _) = Power2Round(t) + // (_, ˆt₀) = NTT(Power2Round(t)) + t1, t0 := make([][n]uint16, k, maxK), priv.t0[:k] + for i := range t { + var w ringElement + for j := range t[i] { + t1[i][j], w[j] = power2Round(t[i][j]) + } + t0[i] = ntt(w) + } + + // The computations below (and their storage in the PrivateKey struct) are + // not strictly necessary and could be deferred to PrivateKey.PublicKey(). + // That would require keeping or re-deriving ρ and t/t1, though. + + pk := pkEncode(priv.pub.raw[:0], ρ, t1, p) + priv.pub.tr = computePublicKeyHash(pk) + computeT1Hat(priv.pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ) + + return priv +} + +func computeMatrixA(A []nttElement, ρ []byte, p parameters) { + k, l := p.k, p.l + for r := range k { + for s := range l { + A[r*l+s] = sampleNTT(ρ, byte(s), byte(r)) + } + } +} + +func computePublicKeyHash(pk []byte) [64]byte { + H := sha3.NewSHAKE256() + H.Write(pk) + var tr [64]byte + H.Read(tr[:]) + return tr +} + +func computeT1Hat(t1Hat []nttElement, t1 [][n]uint16) { + for i := range t1 { + var w ringElement + for j := range t1[i] { + // t₁ <= 2¹⁰ - 1 + // t₁ ⋅ 2ᵈ <= 2ᵈ(2¹⁰ - 1) = 2²³ - 2¹³ < q = 2²³ - 2¹³ + 1 + z, _ := fieldToMontgomery(uint32(t1[i][j]) << 13) + w[j] = z + } + t1Hat[i] = ntt(w) + } +} + +func pkEncode(buf []byte, ρ []byte, t1 [][n]uint16, p parameters) []byte { + pk := append(buf, ρ...) + for _, w := range t1[:p.k] { + // Encode four at a time into 4 * 10 bits = 5 bytes. + for i := 0; i < n; i += 4 { + c0 := w[i] + c1 := w[i+1] + c2 := w[i+2] + c3 := w[i+3] + b0 := byte(c0 >> 0) + b1 := byte((c0 >> 8) | (c1 << 2)) + b2 := byte((c1 >> 6) | (c2 << 4)) + b3 := byte((c2 >> 4) | (c3 << 6)) + b4 := byte(c3 >> 2) + pk = append(pk, b0, b1, b2, b3, b4) + } + } + return pk +} + +func pkDecode(pk []byte, t1 [][n]uint16, p parameters) (ρ []byte, err error) { + if len(pk) != pubKeySize(p) { + return nil, errInvalidPublicKeyLength + } + ρ, pk = pk[:32], pk[32:] + for r := range t1 { + // Decode four at a time from 4 * 10 bits = 5 bytes. + for i := 0; i < n; i += 4 { + b0, b1, b2, b3, b4 := pk[0], pk[1], pk[2], pk[3], pk[4] + t1[r][i+0] = uint16(b0>>0) | uint16(b1&0b0000_0011)<<8 + t1[r][i+1] = uint16(b1>>2) | uint16(b2&0b0000_1111)<<6 + t1[r][i+2] = uint16(b2>>4) | uint16(b3&0b0011_1111)<<4 + t1[r][i+3] = uint16(b3>>6) | uint16(b4&0b1111_1111)<<2 + pk = pk[5:] + } + } + return ρ, nil +} + +var errInvalidPublicKeyLength = errors.New("mldsa: invalid public key length") + +func NewPublicKey44(pk []byte) (*PublicKey, error) { + return newPublicKey(pk, params44) +} + +func NewPublicKey65(pk []byte) (*PublicKey, error) { + return newPublicKey(pk, params65) +} + +func NewPublicKey87(pk []byte) (*PublicKey, error) { + return newPublicKey(pk, params87) +} + +func newPublicKey(pk []byte, p parameters) (*PublicKey, error) { + k, l := p.k, p.l + + t1 := make([][n]uint16, k, maxK) + ρ, err := pkDecode(pk, t1, p) + if err != nil { + return nil, err + } + + pub := &PublicKey{p: p} + copy(pub.raw[:], pk) + computeMatrixA(pub.a[:k*l], ρ, p) + pub.tr = computePublicKeyHash(pk) + computeT1Hat(pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ) + + return pub, nil +} + +var ( + errContextTooLong = errors.New("mldsa: context too long") + errMessageHashLength = errors.New("mldsa: invalid message hash length") + errRandomLength = errors.New("mldsa: invalid random length") +) + +func Sign(priv *PrivateKey, msg []byte, context string) ([]byte, error) { + fipsSelfTest() + fips140.RecordApproved() + var random [32]byte + drbg.Read(random[:]) + μ, err := computeMessageHash(priv.pub.tr[:], msg, context) + if err != nil { + return nil, err + } + return signInternal(priv, &μ, &random), nil +} + +func SignDeterministic(priv *PrivateKey, msg []byte, context string) ([]byte, error) { + fipsSelfTest() + fips140.RecordApproved() + var random [32]byte + μ, err := computeMessageHash(priv.pub.tr[:], msg, context) + if err != nil { + return nil, err + } + return signInternal(priv, &μ, &random), nil +} + +func TestingOnlySignWithRandom(priv *PrivateKey, msg []byte, context string, random []byte) ([]byte, error) { + fipsSelfTest() + fips140.RecordApproved() + μ, err := computeMessageHash(priv.pub.tr[:], msg, context) + if err != nil { + return nil, err + } + if len(random) != 32 { + return nil, errRandomLength + } + return signInternal(priv, &μ, (*[32]byte)(random)), nil +} + +func SignExternalMu(priv *PrivateKey, μ []byte) ([]byte, error) { + fipsSelfTest() + fips140.RecordApproved() + var random [32]byte + drbg.Read(random[:]) + if len(μ) != 64 { + return nil, errMessageHashLength + } + return signInternal(priv, (*[64]byte)(μ), &random), nil +} + +func SignExternalMuDeterministic(priv *PrivateKey, μ []byte) ([]byte, error) { + fipsSelfTest() + fips140.RecordApproved() + var random [32]byte + if len(μ) != 64 { + return nil, errMessageHashLength + } + return signInternal(priv, (*[64]byte)(μ), &random), nil +} + +func TestingOnlySignExternalMuWithRandom(priv *PrivateKey, μ []byte, random []byte) ([]byte, error) { + fipsSelfTest() + fips140.RecordApproved() + if len(μ) != 64 { + return nil, errMessageHashLength + } + if len(random) != 32 { + return nil, errRandomLength + } + return signInternal(priv, (*[64]byte)(μ), (*[32]byte)(random)), nil +} + +func computeMessageHash(tr []byte, msg []byte, context string) ([64]byte, error) { + if len(context) > 255 { + return [64]byte{}, errContextTooLong + } + H := sha3.NewSHAKE256() + H.Write(tr) + H.Write([]byte{0}) // ML-DSA / HashML-DSA domain separator + H.Write([]byte{byte(len(context))}) + H.Write([]byte(context)) + H.Write(msg) + var μ [64]byte + H.Read(μ[:]) + return μ, nil +} + +func signInternal(priv *PrivateKey, μ *[64]byte, random *[32]byte) []byte { + p, k, l := priv.pub.p, priv.pub.p.k, priv.pub.p.l + A, s1, s2, t0 := priv.pub.a[:k*l], priv.s1[:l], priv.s2[:k], priv.t0[:k] + + β := p.τ * p.η + γ1 := uint32(1 << p.γ1) + γ1β := γ1 - uint32(β) + γ2 := (q - 1) / uint32(p.γ2) + γ2β := γ2 - uint32(β) + + H := sha3.NewSHAKE256() + H.Write(priv.k[:]) + H.Write(random[:]) + H.Write(μ[:]) + nonce := make([]byte, 64) + H.Read(nonce) + + κ := 0 +sign: + for { + // Main rejection sampling loop. Note that leaking rejected signatures + // leaks information about the private key. However, as explained in + // https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf + // Section 5.5, we are free to leak rejected ch values, as well as which + // check causes the rejection and which coefficient failed the check + // (but not the value or sign of the coefficient). + + y := make([]ringElement, l, maxL) + for r := range y { + counter := make([]byte, 2) + byteorder.LEPutUint16(counter, uint16(κ)) + κ++ + + H.Reset() + H.Write(nonce) + H.Write(counter) + v := make([]byte, (p.γ1+1)*n/8, (maxγ1+1)*n/8) + H.Read(v) + + y[r] = bitUnpack(v, p) + } + + // w = NTT⁻¹(Â ∘ NTT(y)) + yHat := make([]nttElement, l, maxL) + for i := range y { + yHat[i] = ntt(y[i]) + } + w := make([]ringElement, k, maxK) + for i := range w { + var wHat nttElement + for j := range l { + wHat = polyAdd(wHat, nttMul(A[i*l+j], yHat[j])) + } + w[i] = inverseNTT(wHat) + } + + H.Reset() + H.Write(μ[:]) + for i := range w { + w1Encode(H, highBits(w[i], p), p) + } + ch := make([]byte, p.λ/4, maxλ/4) + H.Read(ch) + + // sampleInBall is not constant time, but see comment above about + // leaking rejected ch values being acceptable. + c := ntt(sampleInBall(ch, p)) + + cs1 := make([]ringElement, l, maxL) + for i := range cs1 { + cs1[i] = inverseNTT(nttMul(c, s1[i])) + } + cs2 := make([]ringElement, k, maxK) + for i := range cs2 { + cs2[i] = inverseNTT(nttMul(c, s2[i])) + } + + z := make([]ringElement, l, maxL) + for i := range y { + z[i] = polyAdd(y[i], cs1[i]) + + // Reject if ||z||∞ ≥ γ1 − β + if coefficientsExceedBound(z[i], γ1β) { + if testingOnlyRejectionReason != nil { + testingOnlyRejectionReason("z") + } + continue sign + } + } + + for i := range w { + r0 := polySub(w[i], cs2[i]) + + // Reject if ||LowBits(r0)||∞ ≥ γ2 − β + if lowBitsExceedBound(r0, γ2β, p) { + if testingOnlyRejectionReason != nil { + testingOnlyRejectionReason("r0") + } + continue sign + } + } + + ct0 := make([]ringElement, k, maxK) + for i := range ct0 { + ct0[i] = inverseNTT(nttMul(c, t0[i])) + + // Reject if ||ct0||∞ ≥ γ2 + if coefficientsExceedBound(ct0[i], γ2) { + if testingOnlyRejectionReason != nil { + testingOnlyRejectionReason("ct0") + } + continue sign + } + } + + count1s := 0 + h := make([][n]byte, k, maxK) + for i := range w { + var count int + h[i], count = makeHint(ct0[i], w[i], cs2[i], p) + count1s += count + } + // Reject if number of hints > ω + if count1s > p.ω { + if testingOnlyRejectionReason != nil { + testingOnlyRejectionReason("h") + } + continue sign + } + + return sigEncode(ch, z, h, p) + } +} + +// testingOnlyRejectionReason is set in tests, to ensure that all rejection +// paths are covered. If not nil, it is called with a string describing the +// reason for rejection: "z", "r0", "ct0", or "h". +var testingOnlyRejectionReason func(reason string) + +// w1Encode implements w1Encode from FIPS 204, writing directly into H. +func w1Encode(H *sha3.SHAKE, w [n]byte, p parameters) { + switch p.γ2 { + case 32: + // Coefficients are <= (q − 1)/(2γ2) − 1 = 15, four bits each. + buf := make([]byte, 4*n/8) + for i := 0; i < n; i += 2 { + b0 := w[i] + b1 := w[i+1] + buf[i/2] = b0 | b1<<4 + } + H.Write(buf) + case 88: + // Coefficients are <= (q − 1)/(2γ2) − 1 = 43, six bits each. + buf := make([]byte, 6*n/8) + for i := 0; i < n; i += 4 { + b0 := w[i] + b1 := w[i+1] + b2 := w[i+2] + b3 := w[i+3] + buf[3*i/4+0] = (b0 >> 0) | (b1 << 6) + buf[3*i/4+1] = (b1 >> 2) | (b2 << 4) + buf[3*i/4+2] = (b2 >> 4) | (b3 << 2) + } + H.Write(buf) + default: + panic("mldsa: internal error: unsupported γ2") + } +} + +func coefficientsExceedBound(w ringElement, bound uint32) bool { + // If this function appears in profiles, it might be possible to deduplicate + // the work of fieldFromMontgomery inside fieldInfinityNorm with the + // subsequent encoding of w. + for i := range w { + if fieldInfinityNorm(w[i]) >= bound { + return true + } + } + return false +} + +func lowBitsExceedBound(w ringElement, bound uint32, p parameters) bool { + switch p.γ2 { + case 32: + for i := range w { + _, r0 := decompose32(w[i]) + if constantTimeAbs(r0) >= bound { + return true + } + } + case 88: + for i := range w { + _, r0 := decompose88(w[i]) + if constantTimeAbs(r0) >= bound { + return true + } + } + default: + panic("mldsa: internal error: unsupported γ2") + } + return false +} + +var ( + errInvalidSignatureLength = errors.New("mldsa: invalid signature length") + errInvalidSignatureCoeffBounds = errors.New("mldsa: invalid signature") + errInvalidSignatureChallenge = errors.New("mldsa: invalid signature") + errInvalidSignatureHintLimits = errors.New("mldsa: invalid signature encoding") + errInvalidSignatureHintIndexOrder = errors.New("mldsa: invalid signature encoding") + errInvalidSignatureHintExtraIndices = errors.New("mldsa: invalid signature encoding") +) + +func Verify(pub *PublicKey, msg, sig []byte, context string) error { + fipsSelfTest() + fips140.RecordApproved() + μ, err := computeMessageHash(pub.tr[:], msg, context) + if err != nil { + return err + } + return verifyInternal(pub, &μ, sig) +} + +func VerifyExternalMu(pub *PublicKey, μ []byte, sig []byte) error { + fipsSelfTest() + fips140.RecordApproved() + if len(μ) != 64 { + return errMessageHashLength + } + return verifyInternal(pub, (*[64]byte)(μ), sig) +} + +func verifyInternal(pub *PublicKey, μ *[64]byte, sig []byte) error { + p, k, l := pub.p, pub.p.k, pub.p.l + t1, A := pub.t1[:k], pub.a[:k*l] + + β := p.τ * p.η + γ1 := uint32(1 << p.γ1) + γ1β := γ1 - uint32(β) + + z := make([]ringElement, l, maxL) + h := make([][n]byte, k, maxK) + ch, err := sigDecode(sig, z, h, p) + if err != nil { + return err + } + + c := ntt(sampleInBall(ch, p)) + + // w = Â ∘ NTT(z) − NTT(c) ∘ NTT(t₁ ⋅ 2ᵈ) + zHat := make([]nttElement, l, maxL) + for i := range zHat { + zHat[i] = ntt(z[i]) + } + w := make([]ringElement, k, maxK) + for i := range w { + var wHat nttElement + for j := range l { + wHat = polyAdd(wHat, nttMul(A[i*l+j], zHat[j])) + } + wHat = polySub(wHat, nttMul(c, t1[i])) + w[i] = inverseNTT(wHat) + } + + // Use hints h to compute w₁ from w(approx). + w1 := make([][n]byte, k, maxK) + for i := range w { + w1[i] = useHint(w[i], h[i], p) + } + + H := sha3.NewSHAKE256() + H.Write(μ[:]) + for i := range w { + w1Encode(H, w1[i], p) + } + computedCH := make([]byte, p.λ/4, maxλ/4) + H.Read(computedCH) + + for i := range z { + if coefficientsExceedBound(z[i], γ1β) { + return errInvalidSignatureCoeffBounds + } + } + + if !bytes.Equal(ch, computedCH) { + return errInvalidSignatureChallenge + } + + return nil +} + +func sigEncode(ch []byte, z []ringElement, h [][n]byte, p parameters) []byte { + sig := make([]byte, 0, sigSize(p)) + sig = append(sig, ch...) + for i := range z { + sig = bitPack(sig, z[i], p) + } + sig = hintEncode(sig, h, p) + return sig +} + +func sigDecode(sig []byte, z []ringElement, h [][n]byte, p parameters) (ch []byte, err error) { + if len(sig) != sigSize(p) { + return nil, errInvalidSignatureLength + } + ch, sig = sig[:p.λ/4], sig[p.λ/4:] + for i := range z { + length := (p.γ1 + 1) * n / 8 + z[i] = bitUnpack(sig[:length], p) + sig = sig[length:] + } + if err := hintDecode(sig, h, p); err != nil { + return nil, err + } + return ch, nil +} + +func hintEncode(buf []byte, h [][n]byte, p parameters) []byte { + ω, k := p.ω, p.k + out, y := sliceForAppend(buf, ω+k) + var idx byte + for i := range k { + for j := range n { + if h[i][j] != 0 { + y[idx] = byte(j) + idx++ + } + } + y[ω+i] = idx + } + return out +} + +func hintDecode(y []byte, h [][n]byte, p parameters) error { + ω, k := p.ω, p.k + if len(y) != ω+k { + return errors.New("mldsa: internal error: invalid signature hint length") + } + var idx byte + for i := range k { + limit := y[ω+i] + if limit < idx || limit > byte(ω) { + return errInvalidSignatureHintLimits + } + first := idx + for idx < limit { + if idx > first && y[idx-1] >= y[idx] { + return errInvalidSignatureHintIndexOrder + } + h[i][y[idx]] = 1 + idx++ + } + } + for i := idx; i < byte(ω); i++ { + if y[i] != 0 { + return errInvalidSignatureHintExtraIndices + } + } + return nil +} diff --git a/mldsa/internal/fips140/mldsa/mldsa_test.go b/mldsa/internal/fips140/mldsa/mldsa_test.go new file mode 100644 index 0000000..74b0ea4 --- /dev/null +++ b/mldsa/internal/fips140/mldsa/mldsa_test.go @@ -0,0 +1,345 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mldsa + +import ( + "bytes" + "crypto/sha256" + "crypto/sha3" + "encoding/hex" + "strings" + "testing" +) + +// Most tests are in crypto/internal/fips140test/mldsa_test.go, so they can +// apply to all FIPS 140-3 module versions. This file contains only tests that +// need access to the unexported symbol testingOnlyRejectionReason. + +func TestACVPRejectionKATs(t *testing.T) { + testCases := []struct { + name string + seed string // input to ML-DSA.KeyGen_internal + keyHash string // SHA2-256(pk || sk) + msg string // M' input to ML-DSA.Sign_internal + sigHash string // SHA2-256(sig) + newPrivateKey func([]byte) (*PrivateKey, error) + newPublicKey func([]byte) (*PublicKey, error) + }{ + // https://pages.nist.gov/ACVP/draft-celi-acvp-ml-dsa.html#table-1 + // ML-DSA Algorithm 7 ML-DSA.Sign_internal() Known Answer Tests for Rejection Cases + + { + "Path/ML-DSA-44/1", + "5C624FCC1862452452D0C665840D8237F43108E5499EDCDC108FBC49D596E4B7", + "AC825C59D8A4C453A2C4EFEA8395741CA404F3000E28D56B25D03BB402E5CB2F", + "951FDF5473A4CBA6D9E5B5DB7E79FB8173921BA5B13E9271401B8F907B8B7D5B", + "DCC71A421BC6FFAFB7DF0C7F6D018A19ADA154D1E2EE360ED533CECD5DC980AD", + NewPrivateKey44, NewPublicKey44, + }, + { + "Path/ML-DSA-44/2", + "836EABEDB4D2CD9BE6A4D957CF5EE6BF489304136864C55C2C5F01DA5047D18B", + "E1FF40D96E3552FAB531D1715084B7E38CCDBACC0A8AF94C30959FB4C7F5A445", + "199A0AB735E9004163DD02D319A61CFE81638E3BF47BB1E90E90D6E3EA545247", + "A2608BC27E60541D27B6A14F460D54A48C0298DCC3F45999F29047A3135C4941", + NewPrivateKey44, NewPublicKey44, + }, + { + "Path/ML-DSA-44/3", + "CA5A01E1EA6552CB5C9803462B94C2F1DC9D13BB17A6ACE510D157056A2C6114", + "A4652DC4A271095268DD84A5B0744DFDBE2E642E4D41FBC4329C2FBA534C0E13", + "8C8CACA88FFF52B9330510537B3701B3993F3726136A650F48F8604551550832", + "B4B142209137397DAD504CAED01D390ADAF49973D8D2414FC3457FB7AF775189", + NewPrivateKey44, NewPublicKey44, + }, + { + "Path/ML-DSA-44/4", + "9C005F1550B4F31855C6B92F978736733F37791CB39DD182D7BA5732BDC2483E", + "2485AA99345F1B334D4D94B610FBFFCCB626CBFD4E9FF0E1F6FC35093C423544", + "B744343F30F7FEE088998BA574E799F1BF3939C06C29BF9AC10F3588A57E21E2", + "5B80A60BAA480B9D0C7D2C05B50928C4BF6808DDA693642058A3EB77EAA768FC", + NewPrivateKey44, NewPublicKey44, + }, + { + "Path/ML-DSA-44/5", + "4FAB5485B009399E8AE6FC3D3EEFBFE8E09796E4477AABD5EB1CC908FA734DE3", + "CB56909A7CF3008A662DC635EDCB79DC151CA7ACBAE17B544384ABD91BBBC1E9", + "7CAB0FDCF4BEA5F039137478AA45C9C48EF96D906FC49F6E2F138111BF1B4A4E", + "6CC38D73D639682ABC556DC6DCF436DE24033091F34004F410FABC6887F77AB0", + NewPrivateKey44, NewPublicKey44, + }, + { + "Path/ML-DSA-65/1", + "464756A985E5DF03739D95DD309C1ED9C5B04254CC294E7E7EB9B9365EE15117", + "AE95EA0DAA80199E7B4A74EB5A1B1DC6C3805BD01D2FA78D7C4FBA8C255AA13D", + "491101BBA044DE6E44A63796C33CDA051BB05A60725B87AF4BA9DB940C03AC09", + "8E08EA0C8DB941685B9905A73B0B57BAD3500B1F73490480B24375B41230CC04", + NewPrivateKey65, NewPublicKey65, + }, + { + "Path/ML-DSA-65/2", + "235A48DB4CA7916B884F424A8586EFD517E87C64AECEC0FCE9A3CC212BA1522E", + "1AC58A909DB4D7BC2473AB5E24AF768279C76F86A82D448258E24EEA4EA6B713", + "F8CE85CB2EC474FFBF5A3FFAE029CE6F4526B8D597655067F97F438B81071E9B", + "AE9531A01738615B6D33C77B3FF618A86E101FDC4C8504681F0EDFA64511AD63", + NewPrivateKey65, NewPublicKey65, + }, + { + "Path/ML-DSA-65/3", + "E13131B705A760305FEFFEBFE99082E2691A444BBEFCC3EDF67D909886200207", + "B422093F95CC489C52F4FA2B8973A2FDDD44426D1D04D1AAEEFC8715D417181F", + "CD365512C7E61BBAA130800B37F3BB46AAF1BEEF3742EA8A9010A6DD4576ED0B", + "3C55E604DECA7B89A99305D7A391C35F66A17C1923F467675EC951C0948D21C9", + NewPrivateKey65, NewPublicKey65, + }, + { + "Path/ML-DSA-65/4", + "0A4793E040A4BC0D0F37643D12C1EA1F10648724609936C76E0EC83E37209E92", + "622D26D536D4D66CD94956B33A74E2E830ED265D25C34FF7C3E5243403146ADF", + "6D9C7A795E48D80A892CBF4D4558429787277E3806EB5D0BCE1640EEBBBF9AEC", + "3B141110B9F56540B2D49AACDE6399974A4EAC40621E367E68D4504F294DB21B", + NewPrivateKey65, NewPublicKey65, + }, + { + "Path/ML-DSA-65/5", + "F865B889E5022D54BABC81CA67E7EB39F1AC42F92CF5295C3DA5C9667DB1B924", + "45BC8EDD1A620C46E973E346844270721824D97888BC174281852D98B7E8F4A3", + "047AFAADBE020ED2D766DA85317DEDE80BE550545F0B21E3F555A990F8004258", + "56308A3578360C41356BA9C97D3240E01767FA76BBBA9FD0CC6CFA9ADD088DB9", + NewPrivateKey65, NewPublicKey65, + }, + { + "Path/ML-DSA-87/1", + "0D58219132746BE077DFE821E9F8FD87857B28AB91D6A567E312A73E2636032C", + "4D261270341A7AC6B66900DDC2B8AB34AB483C897410DDF3B2C072BDDA416434", + "3AA49EF72D010AEC19383BA1E83EC2DD3DCC207A96FFCEB9FFA269E3E3D66400", + "5049DC39045618B903C71595B3A3E07A731F95D37304623ACC98BCEF4258B4CA", + NewPrivateKey87, NewPublicKey87, + }, + { + "Path/ML-DSA-87/2", + "146C47AB9F88408EB76A813294D533B29D7E0FDA75DA5A4E7C69EB61EFEEBB78", + "05194438AF855B79DB8CCCCB647D6BA5C7AAF901BBD09D3B29395F0EA431D164", + "82C44F998A8D24F056084D0E80ECFD8434493385A284C69974923C270D397782", + "CFFC5988A351E14A3EE1282F042A143679C4503814296B27993949A7FF966F57", + NewPrivateKey87, NewPublicKey87, + }, + { + "Path/ML-DSA-87/3", + "049D9B0B646A2AC7F50B63CE5E4BFE44C9B87634F4FF6C14C513E388B8A1F808", + "AC8FE6B2FE26591B129EA536A9A001C785D8ACBDD9489F6E51469A156E9E635D", + "FEBC9F8AE159002BE1A11D395959DD7FC20718135690CDAA2BCFB5801C02AB89", + "FF4006089BDF7337E868F86DDF48F239D2A52EA1D0F686E0103BF19C3B571DB1", + NewPrivateKey87, NewPublicKey87, + }, + { + "Path/ML-DSA-87/4", + "9823DDDE446A8EA883DAD3AC6477F79839FDC2D2DEF2416BE0A8B71CFBC3F5C6", + "525010E307C4EA7667D54EE27007C219B01F4CF88DC3AB2DE8E9AAA59440A884", + "F7592C97C1A96A2F4053588F5CDAD4C50BF7C3752709854FA27779B445DD2BA2", + "FD7757602B83B0A67A314CD5BCC880E7AE47ACDF4D6AF98269028EFB486838F7", + NewPrivateKey87, NewPublicKey87, + }, + { + "Path/ML-DSA-87/5", + "AE213FE8589B414F53780D8B9B6837179967E13CB474C5AD365C043778D2BC90", + "D4988E91064E5DF6D867434D1DED16DCD8533E39E420DC2B4EB9E40A84146F7D", + "19C1913BA76FF04596BB7CC80FD825A5AEDEF5D5AD61CEDB5203E6D7EDB18877", + "23FE743EDD101970D499E7EB57A7AA245BAF417E851B260C55DD525A445F08DA", + NewPrivateKey87, NewPublicKey87, + }, + + // https://pages.nist.gov/ACVP/draft-celi-acvp-ml-dsa.html#table-2 + // ML-DSA Algorithm 7 ML-DSA.Sign_internal() Known Answer Tests for Number of Rejection Cases + + { + "Count/ML-DSA-44/77", + "090D97C1F4166EB32CA67C5FB564ACBE0735DB4AF4B8DB3A7C2CE7402357CA44", + "26D79E4068040E996BC9EB5034C20489C0AD38DC2FEC1918D0760C8621872408", + "E3838364B37F47EDFCA2B577B20B80C3CB51B9F56E0E4CDB7DF002C874039252", + "CD91150C610FF02DE1DD7049C309EFE800CE5C1BC2E5A32D752AB62C5BF5E16F", + NewPrivateKey44, NewPublicKey44, + }, + { + "Count/ML-DSA-44/100", + "CFC73D07A883543A804F770070861825143A62F2F97D05FCE00FD8B25D29A43F", + "89142AB26D6EB6C01FA3F189A9C877597740D685983F29BBDD3596648266AE0E", + "0960C13E9BA467A938450120CC96FF6F04B7E557C99A838619A48F9A38738AB8", + "B6296FFF0C1F23DE4906D58144B00A2DB13AD25E49B4B8573A62EFEECB544DD7", + NewPrivateKey44, NewPublicKey44, + }, + { + "Count/ML-DSA-65/64", + "26B605C78AC762FA1634C6F91DD117C4FBFF7F3A7E7781F0CC83B6281F04AD7F", + "5DA13E571DF80867A8F27E0FF81BE7252A1ABF89B3D6A03D4036AF643EFBB04B", + "C9B07E7DDC0274468F312F5C692A54AC73D1E34D8638E20A2CD3C788F27D4355", + "12A4637E3A833A5A2A46F6A991399E544B62A230B7AA82F7366840FF6A88DE61", + NewPrivateKey65, NewPublicKey65, + }, + { + "Count/ML-DSA-65/73", + "9191CF381BEE17475C011986EFB6AFB1EFA6997442FD33427353F1DA1AA39FC0", + "7930D4E52BA03B61DAA57743B39E291D824DC156356C6B1A8232574D5C8BDD08", + "E616E36E81AA1EC39262109421AE0DDDA5E3B5A8F4A252BCA27AE882538DF618", + "3D758ACE312433D780403B3D4273171FB93D008B395352142C6DC5173E517310", + NewPrivateKey65, NewPublicKey65, + }, + { + "Count/ML-DSA-65/66", + "516912C7B90A3DBE009B7478DBCAF0F5C5C9ED9699A20D0CA56CC516E5A444CD", + "0FD15951B93A4D19446B48D47D32D2CA2253FF43BB8CCCB34C07E5F1A3181B7A", + "9247CA75F9456226A0C783DABCC33FF5B4B489575ADED543E74B29B45F9C8EF2", + "E5CE267800EDF33588451050F9B4A5BF97030D045132A7E3ED9210E74028D23B", + NewPrivateKey65, NewPublicKey65, + }, + { + "Count/ML-DSA-65/65", + "D4B841F882D50AB9E590066BAFABA0F0D04D32641C0B978E54CCAA69A6E8D2C4", + "0039C128DDE6923EA08FF14F5C5C66DCB282B471FD1917DBEBE07C8C45B73F8A", + "175231657B0F3C7065947999467C342064F29BFAEB553E97561407D5560E3AEB", + "8830EA254AF2854BF67C2B907E2321C94FD6EFB2FDAA77669FC3A5C4426C57C9", + NewPrivateKey65, NewPublicKey65, + }, + { + "Count/ML-DSA-65/64", + "5492EB8D811072C030A30CC66B23A173059EBA0D4868CCB92FBE2510B4A5915F", + "573DCD99C86DAE81F6F80CB00AF40846028EA8F9FE63102FE4A78238BC7B660E", + "33D2753ED87D0003B44C1AF5F72EB931F559C6B4931AF7E249F65D3FA7613295", + "84D4AF50933D6E13D4332B86AF0692A66F5030AB01C2EAC4131A5EEBF78CE9E5", + NewPrivateKey65, NewPublicKey65, + }, + { + "Count/ML-DSA-87/64", + "B5C07ECEFE9E7C3B885FDEF032BDF9F807B4011E2DFE6806C088D2081631C8EB", + "5D22F4C40F6EEB96BB891DB15884ED4B0009EA02A24D9D1E9ADFC81C7A42EA7F", + "D1D5C2D167D6E62906790A5FEDF5A0A754CFAF47E6A11AEB93FB8C41934C31F8", + "54F0A9CB26F98B394A35918ECA6760EBD10753FC5CDBA8BE508873AD83538131", + NewPrivateKey87, NewPublicKey87, + }, + { + "Count/ML-DSA-87/65", + "E8FC3C9FAD711DDA2946334FBBD331468D6E9AB48EB86DCD03F300A17AEBC5E5", + "B6C4DC9B20CE5D0F445931EE316CF0676E806D1A6A98868881D060EA27CEB139", + "3B435F7A2CE431C7AB8EAE0991C5DAC610827C99D27803046FBC6C567D6B71F2", + "E337495F08773F14FB26A3E229B9B26D086644C7FDC300267F9DCDD5D78DB849", + NewPrivateKey87, NewPublicKey87, + }, + { + "Count/ML-DSA-87/64", + "151F80886D6CE8C3B428964FE02C40CA0C8EFFA100EE089E54D785344FCCF719", + "127972C33323FEFBF6B69C19E0C86F41558D9AB2B1A8AD6F39BD0A0245DC8D7E", + "C628CE94D2AA99AA50CF15B147D4F9A9C62A3D4612152DE0A502C377F472D614", + "99B552B21432544248BFF47AC8F24CB78DBB25C9683F3ADCB75614BED58A0358", + NewPrivateKey87, NewPublicKey87, + }, + { + "Count/ML-DSA-87/64", + "48BEFFB4C97E59E474E1906F39888BE5AE62F6A011C05EF6A6B8D1E54F2171B7", + "72DA77CF563CBB530129F60129AF989CA4036BA1058267BFBA34A2C70BE803C4", + "D2756A8FB4E47F796AF704ED0FC8C6E573D42DFAB443B329F00F8DB2FF12C465", + "E643914B8556D05360C65EB3E7A06BE7C398B82D49973EEFDC711E65B11EB5E8", + NewPrivateKey87, NewPublicKey87, + }, + { + "Count/ML-DSA-87/69", + "FE2DA9DD93A077FCB6452AC88D0A5762EB896BAAAC6CE7D01CB1370BA8322390", + "7422DBE3F476FFE41A4EFB33F3DDFD8B328029BA3050603866C36CFBC2EE4B87", + "A86B29ADF2300D2636E21D4A350CD18E55A254379C3659A7A95D8734CEC1F005", + "8D25818DD972FFF5B9E9B4CC534A95100A1340C1C81D1486A68939D340E0A58B", + NewPrivateKey87, NewPublicKey87, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + seed := fromHex(tc.seed) + priv, err := tc.newPrivateKey(seed) + if err != nil { + t.Fatalf("NewPrivateKey: %v", err) + } + + if strings.Contains(t.Name(), "/Path/") { + // For path coverage tests, check that we hit all rejection paths. + reached := map[string]bool{"z": false, "r0": false, "ct0": false, "h": false} + // The ct0 rejection is only reachable for ML-DSA-44. + if priv.PublicKey().Parameters() != "ML-DSA-44" { + delete(reached, "ct0") + } + testingOnlyRejectionReason = func(reason string) { + t.Log(reason, "rejection") + reached[reason] = true + } + t.Cleanup(func() { + testingOnlyRejectionReason = nil + }) + defer func() { + for reason, hit := range reached { + if !hit { + t.Errorf("Rejection path %q not hit", reason) + } + } + }() + } + + pk := priv.PublicKey().Bytes() + sk := TestingOnlyPrivateKeySemiExpandedBytes(priv) + keyHashGot := sha256.Sum256(append(pk, sk...)) + keyHashWant := fromHex(tc.keyHash) + + if !bytes.Equal(keyHashGot[:], keyHashWant) { + t.Errorf("Key hash mismatch:\n got: %X\n want: %X", keyHashGot, keyHashWant) + } + + pub, err := tc.newPublicKey(pk) + if err != nil { + t.Fatalf("NewPublicKey: %v", err) + } + if !pub.Equal(priv.PublicKey()) { + t.Errorf("Parsed public key not equal to original") + } + if *pub != *priv.PublicKey() { + t.Errorf("Parsed public key not identical to original") + } + + // The table provides a Sign_internal input (not actually formatted + // like one), which is part of the pre-image of μ. + M := fromHex(tc.msg) + H := sha3.NewSHAKE256() + tr := computePublicKeyHash(pk) + H.Write(tr[:]) + H.Write(M) + μ := make([]byte, 64) + H.Read(μ) + t.Logf("Computed μ: %x", μ) + sig, err := SignExternalMuDeterministic(priv, μ) + if err != nil { + t.Fatalf("SignExternalMuDeterministic: %v", err) + } + + sigHashGot := sha256.Sum256(sig) + sigHashWant := fromHex(tc.sigHash) + + if !bytes.Equal(sigHashGot[:], sigHashWant) { + t.Errorf("Signature hash mismatch:\n got: %X\n want: %X", sigHashGot, sigHashWant) + } + + if err := VerifyExternalMu(priv.PublicKey(), μ, sig); err != nil { + t.Errorf("Verify: %v", err) + } + wrong := make([]byte, len(μ)) + if err := VerifyExternalMu(priv.PublicKey(), wrong, sig); err == nil { + t.Errorf("Verify passed on wrong message") + } + }) + } +} + +func fromHex(s string) []byte { + b, err := hex.DecodeString(s) + if err != nil { + panic(err) + } + return b +} diff --git a/mldsa/internal/fips140/mldsa/semiexpanded.go b/mldsa/internal/fips140/mldsa/semiexpanded.go new file mode 100644 index 0000000..df4b443 --- /dev/null +++ b/mldsa/internal/fips140/mldsa/semiexpanded.go @@ -0,0 +1,244 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mldsa + +import ( + drbg "crypto/rand" + "errors" + "math/bits" +) + +// FIPS 204 defines a needless semi-expanded format for private keys. This is +// not a good format for key storage and exchange, because it is large and +// requires careful parsing to reject malformed keys. Seeds instead are just 32 +// bytes, are always valid, and always expand to valid keys in memory. It is +// *also* a poor in-memory format, because it defers computing the NTT of s1, +// s2, and t0 and the expansion of A until signing time, which is inefficient. +// For a hot second, it looked like we could have all agreed to only use seeds, +// but unfortunately OpenSSL and BouncyCastle lobbied hard against that during +// the WGLC of the LAMPS IETF working group. Also, ACVP tests provide and expect +// semi-expanded keys, so we implement them here for testing purposes. + +func semiExpandedPrivKeySize(p parameters) int { + k, l := p.k, p.l + ηBitlen := bits.Len(uint(p.η)) + 1 + // ρ + K + tr + l × n × η-bit coefficients of s₁ + + // k × n × η-bit coefficients of s₂ + k × n × 13-bit coefficients of t₀ + return 32 + 32 + 64 + l*n*ηBitlen/8 + k*n*ηBitlen/8 + k*n*13/8 +} + +// TestingOnlyNewPrivateKeyFromSemiExpanded creates a PrivateKey from a +// semi-expanded private key encoding, for testing purposes. It rejects +// inconsistent keys. +// +// [PrivateKey.Bytes] must NOT be called on the resulting key, as it will +// produce a random value. +func TestingOnlyNewPrivateKeyFromSemiExpanded(sk []byte) (*PrivateKey, error) { + var p parameters + switch len(sk) { + case semiExpandedPrivKeySize(params44): + p = params44 + case semiExpandedPrivKeySize(params65): + p = params65 + case semiExpandedPrivKeySize(params87): + p = params87 + default: + return nil, errors.New("mldsa: invalid semi-expanded private key size") + } + k, l := p.k, p.l + + ρ, K, tr, s1, s2, t0, err := skDecode(sk, p) + if err != nil { + return nil, err + } + + priv := &PrivateKey{pub: PublicKey{p: p}} + priv.k = K + priv.pub.tr = tr + A := priv.pub.a[:k*l] + computeMatrixA(A, ρ[:], p) + for r := range l { + priv.s1[r] = ntt(s1[r]) + } + for r := range k { + priv.s2[r] = ntt(s2[r]) + } + for r := range k { + priv.t0[r] = ntt(t0[r]) + } + + // We need to put something in priv.seed, and putting random bytes feels + // safer than putting anything predictable. + drbg.Read(priv.seed[:]) + + // Making this format *even more* annoying, we need to recompute t1 from ρ, + // s1, and s2 if we want to generate the public key. This is essentially as + // much work as regenerating everything from seed. + // + // You might also notice that the semi-expanded format also stores t0 and a + // hash of the public key, though. How are we supposed to check they are + // consistent without regenerating the public key? Do we even need to check? + // Who knows! FIPS 204 says + // + // > Note that there exist malformed inputs that can cause skDecode to + // > return values that are not in the correct range. Hence, skDecode + // > should only be run on inputs that come from trusted sources. + // + // so it sounds like it doesn't even want us to check the coefficients are + // within bounds, but especially if using this format for key exchange, that + // sounds like a bad idea. So we check everything. + + t1 := make([][n]uint16, k, maxK) + for i := range k { + tHat := priv.s2[i] + for j := range l { + tHat = polyAdd(tHat, nttMul(A[i*l+j], priv.s1[j])) + } + t := inverseNTT(tHat) + for j := range n { + r1, r0 := power2Round(t[j]) + t1[i][j] = r1 + if r0 != t0[i][j] { + return nil, errors.New("mldsa: semi-expanded private key inconsistent with t0") + } + } + } + + pk := pkEncode(priv.pub.raw[:0], ρ[:], t1, p) + if computePublicKeyHash(pk) != tr { + return nil, errors.New("mldsa: semi-expanded private key inconsistent with public key hash") + } + computeT1Hat(priv.pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ) + + return priv, nil +} + +func TestingOnlyPrivateKeySemiExpandedBytes(priv *PrivateKey) []byte { + k, l, η := priv.pub.p.k, priv.pub.p.l, priv.pub.p.η + sk := make([]byte, 0, semiExpandedPrivKeySize(priv.pub.p)) + sk = append(sk, priv.pub.raw[:32]...) // ρ + sk = append(sk, priv.k[:]...) // K + sk = append(sk, priv.pub.tr[:]...) // tr + for i := range l { + sk = bitPackSlow(sk, inverseNTT(priv.s1[i]), η, η) + } + for i := range k { + sk = bitPackSlow(sk, inverseNTT(priv.s2[i]), η, η) + } + const bound = 1 << (13 - 1) // 2^(d-1) + for i := range k { + sk = bitPackSlow(sk, inverseNTT(priv.t0[i]), bound-1, bound) + } + return sk +} + +func skDecode(sk []byte, p parameters) (ρ, K [32]byte, tr [64]byte, s1, s2, t0 []ringElement, err error) { + k, l, η := p.k, p.l, p.η + if len(sk) != semiExpandedPrivKeySize(p) { + err = errors.New("mldsa: invalid semi-expanded private key size") + return + } + copy(ρ[:], sk[:32]) + sk = sk[32:] + copy(K[:], sk[:32]) + sk = sk[32:] + copy(tr[:], sk[:64]) + sk = sk[64:] + + s1 = make([]ringElement, l) + for i := range l { + length := n * bits.Len(uint(η)*2) / 8 + s1[i], err = bitUnpackSlow(sk[:length], η, η) + if err != nil { + return + } + sk = sk[length:] + } + + s2 = make([]ringElement, k) + for i := range k { + length := n * bits.Len(uint(η)*2) / 8 + s2[i], err = bitUnpackSlow(sk[:length], η, η) + if err != nil { + return + } + sk = sk[length:] + } + + const bound = 1 << (13 - 1) // 2^(d-1) + t0 = make([]ringElement, k) + for i := range k { + length := n * 13 / 8 + t0[i], err = bitUnpackSlow(sk[:length], bound-1, bound) + if err != nil { + return + } + sk = sk[length:] + } + + return +} + +func bitPackSlow(buf []byte, r ringElement, a, b int) []byte { + bitlen := bits.Len(uint(a + b)) + if bitlen <= 0 || bitlen > 16 { + panic("mldsa: internal error: invalid bitlen") + } + out, v := sliceForAppend(buf, n*bitlen/8) + var acc uint32 + var accBits uint + for i := range r { + w := int32(b) - fieldCenteredMod(r[i]) + acc |= uint32(w) << accBits + accBits += uint(bitlen) + for accBits >= 8 { + v[0] = byte(acc) + v = v[1:] + acc >>= 8 + accBits -= 8 + } + } + if accBits > 0 { + v[0] = byte(acc) + } + return out +} + +func bitUnpackSlow(v []byte, a, b int) (ringElement, error) { + bitlen := bits.Len(uint(a + b)) + if bitlen <= 0 || bitlen > 16 { + panic("mldsa: internal error: invalid bitlen") + } + if len(v) != n*bitlen/8 { + return ringElement{}, errors.New("mldsa: invalid input length for bitUnpackSlow") + } + + mask := uint32((1 << bitlen) - 1) + maxValue := uint32(a + b) + + var r ringElement + var acc uint32 + var accBits uint + vIdx := 0 + + for i := range r { + for accBits < uint(bitlen) { + if vIdx < len(v) { + acc |= uint32(v[vIdx]) << accBits + vIdx++ + accBits += 8 + } + } + w := acc & mask + if w > maxValue { + return ringElement{}, errors.New("mldsa: coefficient out of range") + } + r[i] = fieldSubToMontgomery(uint32(b), w) + acc >>= bitlen + accBits -= uint(bitlen) + } + + return r, nil +} diff --git a/mldsa/internal/fips140/mldsa/stubs.go b/mldsa/internal/fips140/mldsa/stubs.go new file mode 100644 index 0000000..89e501a --- /dev/null +++ b/mldsa/internal/fips140/mldsa/stubs.go @@ -0,0 +1,10 @@ +package mldsa + +// Stubs for functions that are implemented in the upstream +// crypto/internal/fips140/mldsa package, to minimize the diff. + +import "sync" + +func fipsPCT(priv *PrivateKey) {} + +var fipsSelfTest = sync.OnceFunc(func() {}) diff --git a/mldsa/internal/fips140/stubs.go b/mldsa/internal/fips140/stubs.go new file mode 100644 index 0000000..32d6bb9 --- /dev/null +++ b/mldsa/internal/fips140/stubs.go @@ -0,0 +1,5 @@ +package fips140 + +const Enabled = false + +func RecordApproved() {} diff --git a/mldsa/internal/fips140test/mldsa_test.go b/mldsa/internal/fips140test/mldsa_test.go new file mode 100644 index 0000000..db87625 --- /dev/null +++ b/mldsa/internal/fips140test/mldsa_test.go @@ -0,0 +1,729 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fipstest + +import ( + "crypto/sha3" + "encoding/hex" + "flag" + "math/rand" + "testing" + + "github.com/go-webauthn/x/mldsa/internal/cryptotest" + "github.com/go-webauthn/x/mldsa/internal/fips140" + . "github.com/go-webauthn/x/mldsa/internal/fips140/mldsa" +) + +var sixtyMillionFlag = flag.Bool("60million", false, "run 60M-iterations accumulated test") + +// TestMLDSAAccumulated accumulates 10k (or 100, or 60M) random vectors and checks +// the hash of the result, to avoid checking in megabytes of test vectors. +// +// 60M in particular is enough to give a 99.9% chance of hitting every value in +// the base field. +// +// 1-((q-1)/q)^60000000 ~= 0.9992 +// +// If setting -60million, remember to also set -timeout 0. +func TestMLDSAAccumulated(t *testing.T) { + t.Run("ML-DSA-44/100", func(t *testing.T) { + testMLDSAAccumulated(t, NewPrivateKey44, NewPublicKey44, 100, + "d51148e1f9f4fa1a723a6cf42e25f2a99eb5c1b378b3d2dbbd561b1203beeae4") + }) + t.Run("ML-DSA-65/100", func(t *testing.T) { + testMLDSAAccumulated(t, NewPrivateKey65, NewPublicKey65, 100, + "8358a1843220194417cadbc2651295cd8fc65125b5a5c1a239a16dc8b57ca199") + }) + t.Run("ML-DSA-87/100", func(t *testing.T) { + testMLDSAAccumulated(t, NewPrivateKey87, NewPublicKey87, 100, + "8c3ad714777622b8f21ce31bb35f71394f23bc0fcf3c78ace5d608990f3b061b") + }) + if !testing.Short() { + t.Run("ML-DSA-44/10k", func(t *testing.T) { + t.Parallel() + testMLDSAAccumulated(t, NewPrivateKey44, NewPublicKey44, 10000, + "e7fd21f6a59bcba60d65adc44404bb29a7c00e5d8d3ec06a732c00a306a7d143") + }) + t.Run("ML-DSA-65/10k", func(t *testing.T) { + t.Parallel() + testMLDSAAccumulated(t, NewPrivateKey65, NewPublicKey65, 10000, + "5ff5e196f0b830c3b10a9eb5358e7c98a3a20136cb677f3ae3b90175c3ace329") + }) + t.Run("ML-DSA-87/10k", func(t *testing.T) { + t.Parallel() + testMLDSAAccumulated(t, NewPrivateKey87, NewPublicKey87, 10000, + "80a8cf39317f7d0be0e24972c51ac152bd2a3e09bc0c32ce29dd82c4e7385e60") + }) + } + if *sixtyMillionFlag { + t.Run("ML-DSA-44/60M", func(t *testing.T) { + t.Parallel() + testMLDSAAccumulated(t, NewPrivateKey44, NewPublicKey44, 60000000, + "080b48049257f5cd30dee17d6aa393d6c42fe52a29099df84a460ebaf4b02330") + }) + t.Run("ML-DSA-65/60M", func(t *testing.T) { + t.Parallel() + testMLDSAAccumulated(t, NewPrivateKey65, NewPublicKey65, 60000000, + "0af0165db2b180f7a83dbecad1ccb758b9c2d834b7f801fc49dd572a9d4b1e83") + }) + t.Run("ML-DSA-87/60M", func(t *testing.T) { + t.Parallel() + testMLDSAAccumulated(t, NewPrivateKey87, NewPublicKey87, 60000000, + "011166e9d5032c9bdc5c9bbb5dbb6c86df1c3d9bf3570b65ebae942dd9830057") + }) + } +} + +func testMLDSAAccumulated(t *testing.T, newPrivateKey func([]byte) (*PrivateKey, error), newPublicKey func([]byte) (*PublicKey, error), n int, expected string) { + s := sha3.NewSHAKE128() + o := sha3.NewSHAKE128() + seed := make([]byte, PrivateKeySize) + msg := make([]byte, 0) + + for i := 0; i < n; i++ { + s.Read(seed) + dk, err := newPrivateKey(seed) + if err != nil { + t.Fatalf("NewPrivateKey: %v", err) + } + pk := dk.PublicKey().Bytes() + o.Write(pk) + sig, err := SignDeterministic(dk, msg, "") + if err != nil { + t.Fatalf("SignDeterministic: %v", err) + } + o.Write(sig) + pub, err := newPublicKey(pk) + if err != nil { + t.Fatalf("NewPublicKey: %v", err) + } + if *pub != *dk.PublicKey() { + t.Fatalf("public key mismatch") + } + if err := Verify(dk.PublicKey(), msg, sig, ""); err != nil { + t.Fatalf("Verify: %v", err) + } + } + + sum := make([]byte, 32) + o.Read(sum) + got := hex.EncodeToString(sum) + if got != expected { + t.Errorf("got %s, expected %s", got, expected) + } +} + +func TestMLDSAGenerateKey(t *testing.T) { + t.Run("ML-DSA-44", func(t *testing.T) { + testMLDSAGenerateKey(t, GenerateKey44, NewPrivateKey44) + }) + t.Run("ML-DSA-65", func(t *testing.T) { + testMLDSAGenerateKey(t, GenerateKey65, NewPrivateKey65) + }) + t.Run("ML-DSA-87", func(t *testing.T) { + testMLDSAGenerateKey(t, GenerateKey87, NewPrivateKey87) + }) +} + +func testMLDSAGenerateKey(t *testing.T, generateKey func() *PrivateKey, newPrivateKey func([]byte) (*PrivateKey, error)) { + k1 := generateKey() + k2 := generateKey() + if k1.Equal(k2) { + t.Errorf("two generated keys are equal") + } + k1x, err := newPrivateKey(k1.Bytes()) + if err != nil { + t.Fatalf("NewPrivateKey: %v", err) + } + if !k1.Equal(k1x) { + t.Errorf("generated key and re-parsed key are not equal") + } +} + +func TestMLDSAAllocations(t *testing.T) { + // We allocate the PrivateKey (k and kk) and PublicKey (pk) structs and the + // public key (pkBytes) and signature (sig) byte slices on the heap. They + // are all large and for the byte slices variable-length. Still, check we + // are not slipping more allocations in. + var expected float64 = 5 + if fips140.Enabled { + // The PCT does a sign/verify cycle, which allocates a signature slice. + expected += 1 + } + cryptotest.SkipTestAllocations(t) + if allocs := testing.AllocsPerRun(100, func() { + k := GenerateKey44() + seed := k.Bytes() + kk, err := NewPrivateKey44(seed) + if err != nil { + t.Fatalf("NewPrivateKey44: %v", err) + } + if !k.Equal(kk) { + t.Fatalf("keys not equal") + } + pkBytes := k.PublicKey().Bytes() + pk, err := NewPublicKey44(pkBytes) + if err != nil { + t.Fatalf("NewPublicKey44: %v", err) + } + message := []byte("Hello, world!") + context := "test" + sig, err := Sign(k, message, context) + if err != nil { + t.Fatalf("Sign: %v", err) + } + if err := Verify(pk, message, sig, context); err != nil { + t.Fatalf("Verify: %v", err) + } + }); allocs > expected { + t.Errorf("expected %0.0f allocations, got %0.1f", expected, allocs) + } +} + +func BenchmarkMLDSASign(b *testing.B) { + // Signing works by rejection sampling, which introduces massive variance in + // individual signing times. To get stable but correct results, we benchmark + // a series of representative operations, engineered to have the same + // distribution of rejection counts and reasons as the average case. See also + // https://words.filippo.io/rsa-keygen-bench/ for a similar approach. + b.Run("ML-DSA-44", func(b *testing.B) { + benchmarkMLDSASign(b, NewPrivateKey44, benchmarkMessagesMLDSA44) + }) + b.Run("ML-DSA-65", func(b *testing.B) { + benchmarkMLDSASign(b, NewPrivateKey65, benchmarkMessagesMLDSA65) + }) + b.Run("ML-DSA-87", func(b *testing.B) { + benchmarkMLDSASign(b, NewPrivateKey87, benchmarkMessagesMLDSA87) + }) +} + +func benchmarkMLDSASign(b *testing.B, newPrivateKey func([]byte) (*PrivateKey, error), messages []string) { + seed := make([]byte, 32) + priv, err := newPrivateKey(seed) + if err != nil { + b.Fatalf("NewPrivateKey: %v", err) + } + rand.Shuffle(len(messages), func(i, j int) { + messages[i], messages[j] = messages[j], messages[i] + }) + i := 0 + for b.Loop() { + msg := messages[i] + if i++; i >= len(messages) { + i = 0 + } + SignDeterministic(priv, []byte(msg), "") + } +} + +// BenchmarkMLDSAVerify runs both public key parsing and signature verification, +// since pre-computation can be easily moved between the two, but in practice +// most uses of verification are for fresh public keys (unlike signing). +func BenchmarkMLDSAVerify(b *testing.B) { + b.Run("ML-DSA-44", func(b *testing.B) { + benchmarkMLDSAVerify(b, GenerateKey44, NewPublicKey44) + }) + b.Run("ML-DSA-65", func(b *testing.B) { + benchmarkMLDSAVerify(b, GenerateKey65, NewPublicKey65) + }) + b.Run("ML-DSA-87", func(b *testing.B) { + benchmarkMLDSAVerify(b, GenerateKey87, NewPublicKey87) + }) +} + +func benchmarkMLDSAVerify(b *testing.B, generateKey func() *PrivateKey, newPublicKey func([]byte) (*PublicKey, error)) { + priv := generateKey() + msg := make([]byte, 128) + sig, err := SignDeterministic(priv, msg, "context") + if err != nil { + b.Fatalf("SignDeterministic: %v", err) + } + pub := priv.PublicKey().Bytes() + for b.Loop() { + pk, err := newPublicKey(pub) + if err != nil { + b.Fatalf("NewPublicKey: %v", err) + } + if err := Verify(pk, msg, sig, "context"); err != nil { + b.Fatalf("Verify: %v", err) + } + } +} + +func BenchmarkMLDSAKeygen(b *testing.B) { + b.Run("ML-DSA-44", func(b *testing.B) { + for b.Loop() { + NewPrivateKey44(make([]byte, 32)) + } + }) + b.Run("ML-DSA-65", func(b *testing.B) { + for b.Loop() { + NewPrivateKey65(make([]byte, 32)) + } + }) + b.Run("ML-DSA-87", func(b *testing.B) { + for b.Loop() { + NewPrivateKey87(make([]byte, 32)) + } + }) +} + +var benchmarkMessagesMLDSA44 = []string{ + "BUS7IAZWYOZ4JHJQYDWRTJL4V7", + "MK5HFFNP4TB5S6FM4KUFZSIXPD", + "DBFETUV4O56J57FXTXTIVCDIAR", + "I4FCMZ7UNLYAE2VVPKTE5ETXKL", + "56U76XRPOVFX3AU7MB2JHAP6JX", + "3ER6UPKIIDGCXLGLPU7KI3ODTN", + "JPQDX2IL3W5CYAFRZ4XUJOHQ3G", + "6AJOEI33Z3MLEBVC2Q67AYWK5L", + "WE3U36HYOPJ72RN3C74F6IOTTJ", + "NMPF5I3B2BKQG5RK26LMPQECCX", + "JRGAN2FA6IY7ESFGZ7PVI2RGWA", + "UIKLF6KNSIUHIIVNRKNUFRNR4W", + "HA252APFYUWHSZZFKP7CWGIBRY", + "JFY774TXRITQ6CIR56P2ZOTOL6", + "ZASYLW5Y3RAOC5NDZ2NCH5A4UY", + "42X4JXNPXMFRCFAE5AKR7XTFO7", + "YAHQUWUH534MUI2TYEKQR7VR3A", + "HBP7FGEXGSOZ5HNOVRGXZJU2KG", + "HG4O7DCRMYMQXASFLMYQ6NMIXK", + "2KPQMDZKS65CLJU4DHTMVV5WI3", + "G6YSUTEX4HHL44ISK2JVVK45BV", + "PUJGPEQUBQM3IK2EXDQFJ2WGBG", + "PNS6HMQAWA3RORSMSNEUAINMIR", + "L35MZS4XYIJK453OFXCZG4WHIK", + "CRY54YZMFRF6JTB3FPNNBWPUOG", + "Y25TSZBWGU4HJCRMWZHAWXQ2DN", + "23W64TW3AKZPKCM4HMKEHFI6VQ", + "PWQAOZ24B4VLNEQR4XKN7LZHDI", + "YINPDR3ZSAKPPXP6J6VAXHIPYO", + "JDBB52ZRAB3PYBPNE7P4COY5PJ", + "4DYU52LQLVG3LTREOTLBCJK3XC", + "AB45MV6RKUGPCW4EUK7DX23MJX", + "HEJSITE5K7J6YJ74OEATVTCERV", + "ZKI5QCFCGM26UK7F5KYTENXKD2", + "VH5G3ZLF5XC22QAEJ6JDGOBE5Y", + "HYGXFHH3JW5SENG26MXLL54IGV", + "MJUCRL36JZ757UYHBFPCJBPZRH", + "IBH3T6NAVLCJQBYSVHAQFUITYA", + "VMWCS7JMIMFQB6TPRAMOUXIKWD", + "SXRPGPNNW2MMBKQS3HJURIQ3XV", + "YPPYMJZW6WYXPSCZIPI57NTP5L", + "N3SH6DUH6UOPU7YMQ6BJJEQSPI", + "Q243DGA6VC6CW66FFUAB5V3VLB", + "OUUBXEU4NJBRN5XZJ7YQUPIZLA", + "H5TWHVGC7FXG6MCKJQURD3RNWG", + "OONG2ZZ7H3P5BREEEURNJHBBQG", + "HWROSSRTBCQOAIQAY5S4EQG4FX", + "AJW6PW62JQNU72VKGIQMPBX64C", + "OXECVUVAWBBBXGGQGQBTYVEP4S", + "M5XN6V2LQJDEIN3G4Z6WJO6AVT", + "NHGJUX3WGRTEIRPFWC2I467ST4", + "SEOADTJDKAYYLDSC4VAES2CRDJ", + "J5AT674S577ZFGEURNIAGYOHKW", + "VJQVNMGHG4ITFX2XSPSDEWVZWD", + "ZWY3KJPXTAVWWVHNAJDUXZ52TG", + "HY46PBUGP4EMH34C6Q56MO7CJP", + "MQTUO7CF6R6CRJPVV6F673M6VW", + "35Z2Z5KV2RBJPQ7OZ24ZJE6BKR", + "OVUEVXBLCU2BBY25QP5WJACDIX", + "LNJX7PCLYL35WYJBW6CTXENPUU", + "IH7E766LCENOQ5ZKZVCMLEPACU", + "T2HZFGDDSFQ6YADB52NIFLBFEV", + "RHQUJMN4MB5SYY4FP4ARZH52QJ", + "W7GZC5ZM63UF2EJ7OC4WJM3OTH", + "T2NHNFVOMICY33AQZSR53HXFQ6", + "7ZVB4Y4K4Y2VAM5NC7HHAJNZIB", + "UX2I4VF62XJGP2XTNN6LDKXTOH", + "HJAMJR5RQTQW7JMW7ZLPRBZE7E", + "HKWSKX7MB5346PHYNWNBAYDSYK", + "BVWSB75HFLLE45MWA6EPHPTCFR", + "YDH2J6NMM7UINHGUOPIUI7PSSR", + "SYQPZLK52HMUAQFMVHGRJYKBEY", + "7AA6UQFGSPBGNUDPLWXSGNKKPP", + "AYXRJGRWZ5S3QOEDVWYHHCICHV", + "KFJYAWO7IATSBCSTDUAA5EPFAN", + "3JABTLB6T2ICHGVT3HXZZ3OAIT", + "WCM3IBOCQJ36WSG627CCNK3QA7", + "5FB5H3BZN2J4RGR2DUW7M37NKZ", + "VKDDAD3BVOMPSNEDGIRHKX5S6R", + "LFH5HVUR726OSFD3YVYM3ZHEIH", + "Y4ETQB2KZVFB4M7SALLCTHX2FB", + "E6SAU3C25MO2WBBVBKCKP2N4ZE", + "3JA54Q3NEKURB5EAPL2FOFIESD", + "FZPBW7BIQIW3FTKQD4TLKNWLMD", + "LY5W6XFA2ZRI53FTUJYGWZ5RX6", + "QID236JY3ICR55O5YRED33O7YT", + "HDRU3L6MFEBCBQFNLF5IRPMOAL", + "232ANKJBDBG4TSKQ7GJMWTHT23", + "CDWE3CELZM5AOJGYEFHMUNSP5O", + "7LNJRBOKN6W7RXUU34MDJ2SNKL", + "S3IZOADTW2A6E5IGRO5WKX7FVH", + "ZAISTLXC55EBMTN6KZ6QX5S7OS", + "4Z5ZIVCMFR2PY2PY4Z47T4YPYA", + "NE36L53Z6AMYQU7Q5REFUF76MK", + "WND5UP5M6KWPBRFP5WIWTOWV3I", + "7OC54DLFWMADJEMKEJ3Y2FMMZS", + "BWJVZHGEN43ULNIOZCPZOB64HG", + "VDFPQSR7RE54A75GT4JDZY5JK2", + "HFCD5EPBZBSVMXIDA47DZ6MRD6", + "RNBVFIUUJUM7EHRE3VNWSTORGO", + "VO5NLQJBR22CRRYUETGTU6JLMR", + "RZOMNFHBTL6HMGWH4PEEDASK7U", + "QL73UBTOLK5O2TW43YWAIKS6T3", + "NE3QVSMWS5G3W5C3BMKTJNMI2L", + "YHI6EYQ4GZMB2QPGHPUG2ZUOEL", + "6MBATW7MFNRUQBFD3GM35B7YPM", + "AIYRY6P5T4XU44CGVPEV6W43FR", + "MIAQ2FHXMAPY5NXSS45VRDPRMG", + "2SNLHQYKK2K6NSWOF6KPGZ3CPC", + "RVBHIQO5LH77ZWEAO3SVL72M2V", + "XXTGJCJNRSNLE7ARAH2UU6LVKR", + "DQMGILY5IDMWN5OYQYYXH26ZGR", + "627VTXXMM455KMTFNUUTKNFXPY", + "HC7IBFGLZCWGUR4K7REPMPW6W4", + "CHL6JRQUS7D4NML3PFT37PPZAA", + "Y767HXJAGJ75KE3JLO4DTLQIXC", + "NTIODXI5I7TF2KXXWXOAYGT7G4", + "PKZYEK2WAI4D4HEYYZH6H5IOMP", + "FG6J6G7HZDEDF4JQBQOTC7RQGZ", + "3VHM2VZU77Y25E3UUYZJLB2QLA", + "WRZQJQW7ARH4DXYHVLCJ4HRTTB", + "LQXKV5HD2AZHENSJ2VFLJ5YU5L", + "MF6Q4OA2EN6TG6BUDK7RWCQNPU", + "3USKYKPC5CB3EC4ZRMZVE3R2UO", + "3WICO2GVS3IRBFUHNDLNKWVP7N", + "P6ZR2UZZOVUZKT4KUS5WICW5XE", + "PYPZUU76RYVOUZGUUX33HLDKYA", + "2FTSURHV34VYTVIUU7W6V5C3NK", + "YABDYMGXS2MD2CYF3S4ALG4FLG", + "MHIBDH25RRPWV3P4VAWT6SAX3I", + "OINSMWJQ2UTOOKZ3X6ICXXBQR7", + "PFTQS7JNU2Q3Q6L4CGBXVLOYNE", + "A4MZ7CCVYQUDJ2AFHNXBBQ3D24", + "CPUB5R3ORTCMSMCLUQURE6AN5O", + "NF5E7U3DFTXWFFXXHUXTEP4VZQ", + "AWB5WDFERWSSJG53YGJMDORQKR", + "U5JQUILKD6SEL6LXAMNFZP6VSW", + "M45NLOAFLO74EJKG5EXNET6J5Y", + "P2KTEUMZ5DZZMYSPOHDR2WJXAN", + "KVO7AXZNFBUBPYLOTZQQ42TFNS", + "WGJJ7SAEV6SBBWWYS4BTLD63WM", + "Y6GURVDV4ESRBPWSTV25T4PE4K", + "ESK7MPFPUZ5ZAQ52RP4SQIYCCC", + "623M3CIABZ3RANERQ2IREXAVYO", + "OQ4CQCFO42RS4BMMSGSDLUTOQO", + "AMFHRDVGM6G2TIR3TKIFGFSDVM", + "7VVSGGCVC53PLOYG7YHPFUJM5X", + "Z3HMESVL7EZUSZNZ33WXEBHA2N", + "AWWVRQD5W7IBSQPS26XOJVDV5H", + "OQBZ5ZST3U3NZYHSIWRNROIG6L", + "II573BW7DJLBYJSPSYIABQWDZD", + "MOKXOQFOCUCLQQH4UKH2DPE7VN", + "XR54NGUOU6BBUUTINNWBPJ35HX", + "DNK36COZGFXI6DY7WLCNUETIRT", + "R5M2PV7E3EHEM3TLGRCL3HSFMC", + "ITKENZQYDQMZFCUPOT7VF3BMU7", + "5GDCB74PPPHEP5N5G3DVRCYT7R", + "ZMKXVRPLI5PY5BDVEPOA3NQZGN", + "GBLIALWTHTUDTOMDERQFVB77CS", + "VKRTTXUTFOK4PJAQQZCCT7TV3T", + "ZJBUJJ4SW62BXOID3XO2W2M2PF", + "SKWT5T6QJTCD3FCINIK22KMVBJ", + "EHINNU6L33HRLOOJ3A2XFJSYQL", + "N4HRQJEFPAT5SU3YPO74WSMQIR", + "TGPTZ3ENMFWB5CZKJFR5WHIRI4", + "O4HNFTAUJJ2LZPQXPXRAXOVABA", + "4JVB5STP2YG5GYOXDWIF4KCKFB", + "MY554X3YZHBECLHNNZ7A3SPJTU", + "ASCJMAH7VCQAD2QJSWXPSVSM3H", + "NBNGL5DZ623KCG2JNZFGZMZ7KD", + "KGMZSW35AEQOJ6FA7IR7BHZI52", + "Q7QUHHS4OJFMJ4I3FY6TDKSMZQ", + "MZAE7TOEXAS76T7KIC73FEYRU4", + "2BVESR3REAWADCGYOYM7T646RG", + "EK3L2ORP4LT3HU3EMXDSQWFOKJ", + "3X4A6VMGMIDLVK72FZSDHSERWY", + "I3UHWI6M6HQFRBSQ6W2SABUNUP", + "REKPXW4DIB4MTKMPHN3RBVHVME", + "W37FNFZE35NX65Z7CVQ7L5U4L5", + "4AGYK6U2KP6RAOADCBUDDCBECV", + "IXM4SFQUDW2NOTXZIPWTNGET3F", + "6YE4G3VELF27MN3Z5B4VIQ3XYK", + "LPOZCPZAG3MD47MIWGR4FIOCDH", + "WGREKUL2LD7C7SYGKH7APIY2A6", + "WWW277FKTKUXQMP4BECSRHLWJI", + "UYE4IQPMSTXVQG7EJALKWWEGDN", + "TIV2L5Z6K7SNGNUVWSNKTAF4UE", + "I3FQOAW3PINUK26P62HCX657FO", +} + +var benchmarkMessagesMLDSA65 = []string{ + "NDGEUBUDWGRJJ3A4UNZZQOEKNL", + "ACGYQUXN4POOFUENCLNCIPHFAZ", + "Z3XETEYKROVJH7SIHOIAYCTO42", + "DXWCVCEFULV7XHRWHJWSEXWES7", + "BCR2D5PNLGFYX6B3QFQFV23JZP", + "2DVP5HNG54ES64QK4D37PWUYTJ", + "UJM4ADPJLURAIQH4XA6QYUGNJ6", + "B5WRCIPK5IVZW52R6TJOKNPKZH", + "7QNL6JTSP62IGX6RCM2NHRMTKK", + "EJSZQYLM7G7AJCGIEVBV2UW7NN", + "UFNA2NKJ3QFWNHHL5CXZ4R5H46", + "QZAXRTT3E4DOGVTJCOTBG3WXQV", + "KH2ETOYZO5UHIHIKATWJMUVG27", + "V5HVVQTOWRXZ2PB4XWXSEKXUN5", + "5LA7NAFI2LESMH533XY45QVCQW", + "SMF4TWPTMJA2Z4F4OVETTLVRAY", + "FWZ5OJAFMLTQRREPYF4VDRPPGI", + "OK3QMNO3OZSKSR6Q4BFVOVRWTH", + "NQOVN6F6AOBOEGMJTVMF67KTIJ", + "CCLC4Y6YT3AQ3HGT2QNSYAUGNV", + "CAZJHCHBUYQ6OKZ7DMWMDDLIZQ", + "LVW5XDTHPKOW5D452SYD7AFO6Q", + "EYA6O6FTYPC6TRKZPRPX5N2KQ4", + "Z6SGAEZ2SAAZHPQO7GL7CUMBAG", + "FKUCKW6JQVF4WQYXUSXYZQMAVY", + "LN2KDF4DANPE4SC4GKJ4BES3IZ", + "AVCRTWB6ALOQHY34XI7NTMP2JH", + "A5WHIS6CBWPCYIEC6N2MBAOEZ6", + "JC2BH476BXUQFIDA6UCR5V4G4F", + "NU6XH6VLSSFHVSRZCYXPFYKYCD", + "GSUXVZBDDYSZYFGXNP6AZW3PTC", + "XJPRNJ26XP4MIYH2Q7M7MPZ73M", + "INUTUP3IRFWIIT23DNFTIYKCFY", + "T4KH7HKLEYGXHBIRFGFCRUZCC4", + "GGQX4JFVWZHE5Y73YTLMSSOXNS", + "BUA4Q3TQZGLVHMMJU62GQOSHLV", + "WXW3SJXLSZO2MYF4YFIMXL2IQP", + "Q32XBVVGFQTSXAIDJE6XSEPRZG", + "6TEXT6SA7INRCTDSCSVZJEQ2YG", + "ZBN4UL43C3SJIG4HYR236PXCVS", + "TVWPLLC7NROBREWOM75VA3XCR3", + "CCDGL2FURLBABQ4IJBYCB75JFR", + "XBZGCOVTZHCPAARBTMAKPIE6GJ", + "TPRAENJ7I54XRIVH6LL6FDIA3I", + "RKOM3PHFILPIIQZL4ILQWGRYWI", + "CEEZIZ2WUXHQQFATYYGQ3ZDBTI", + "SLKOVAP6WLIVJBVU7VZG3ZGEOW", + "TWMCLJJSWEEQQPQGGDKEJ5SU2R", + "IFMUXXCD2LC7IGQLZ2QEK5UOQ2", + "C7IWFEBHW2CXN4XBJS7VLWH3VK", + "7KJYUEW3F264727TM4LE6RMGDO", + "BPG2XAPBMBTA4VMPUM7IZVZPK3", + "Y5X577BWRZNPLNUHJVSKGMUXYB", + "ZCKMKM23E4IUPTNQDFN2LTLZVX", + "4RKK223JNBDAP4G5DOAHHZ3VNO", + "5UZ3TQZHZT22ISTB4WJEVO6MC4", + "YMVS4HFSJ32CRZRL23PXZUEJFJ", + "UQEUJUTPSZLZARNBXWMCTMHPFF", + "CZAAZ5WK7EIPMW7NA3EZNNBF45", + "227PBHH23WM7F2QLEZSPFYXVW4", + "YUYS2J5CRFXZ4J4KJT2ZKIZVW3", + "MFLHZJOZV44SN4AH6OJ3QZWM2O", + "H2B3CRBCXYN7QWDGYUPHQZP23A", + "T4L6YWQUQ3CTACENAJ5WUXZWFH", + "N723H6MUGPZSRZ72C635OD4BP7", + "NI4TUMVA6LQPQV2TXPN4QOIGBZ", + "CQI3S4LSTQASSJJVZXEFPOVW7K", + "ANPY4HJ64LLSB3GK2R4C6WDBS3", + "RGWQCZKQLMT5FZRDE4B3VMASVK", + "Q3WCCF2HA3CA4WWRJBMGBW7WI7", + "2AKJRXFHXLUQPOXPTLSZN5PW4A", + "IJWOOTI4N7RWXJIHAPXN6KEWEN", + "4D53T6N6ATOVTD4LKSTAAWBJMU", + "B4G5HDD6RITG6NIH6FXCRZDYZM", + "TJCDFKMRUY2OG6KRSMNVCGQFUP", + "PB33IHQKALAY6H6GVBVLI6ZRXK", + "SCCWGW2J5S4WL4FTTMQ435F6DB", + "ZVJH2HSMTLHGXMGPMXLJCKCLLE", + "62LG37U6JXR77YRZQQCDSBHVCS", + "BU4CBWOXQ352TEOKIXO245ID4O", + "UEZOH7KEIODSEVRUF6GMWGA2RB", + "IPJWROME4GM66CGLUWP5BJ4SX6", + "355GDC7TG64AZJ7IJX6K62KZCZ", + "AHTFKX3V7XUB3EWOMQVCGZYGUE", + "N4RV2GKXJ4SPHHJ52Z7K5EGLER", + "ZY7V7NE5F66XHDHWM6YNFEWZA6", + "DIKFO5KAVT4WAP7BOEFM56ZUSR", + "4TDFOFKDAPIOM3MU5GD7NPXNWQ", + "AD7YZO756HDK6YWFILAKW3JWA7", + "NUA53JS2ZK2BGHH3A7BJTJZYW7", + "QLCNC3AQNKLRMSYR62WQSQP5VI", + "SJ7OBS7ZYXSGXOYXPE5KW2XKN6", + "44HBMOGMIMJS63CEXQU7FCXE2E", + "KCK3J7ZL6QF4SLHHSWTJURK7PG", + "HLH4CLUGBSOOBSS3BPO62N5MC3", + "3FNS4GITO6OEUBAVDDXK4WOBTD", + "IAC3K3I4AQGY3G6UHG7PL2N6TE", + "KUKLNH74POJI5DYAEWUD7RABTQ", + "ETM6N7VU3GBSQ7P5MCD6UF3E3S", + "IZITM5NYBGJZLSI3BI4VEMW43U", + "46OPQU4LL6N3Z2U7KYPKUMBAGI", + "EV7YZ5DMAV7VKYJQUFSRD37GPP", + "AV7W2PGYDJIAKLFVEBL6BXQSGC", + "M2FOX5QZEZKV4QXKPI5XUZDHEM", + "R4IFPLVMOVYCHRTR6LXAUGP3LL", + "JGH6XJUMP4DRVAM27P2JNOKXVO", + "D2XN3ZLLU6VFPMDYM7NBHSQEOI", + "2PO3BYENOMQK6SHQDCFSRPJQI3", + "IBVQ7U3QEUC6PQRE4PV53JTZTK", + "ZBCOX4P7NG2IXXFB2R43MG2SLV", + "5NJDPQVVDO7ADNZ2CV7L6QBNGZ", + "V7ASFIIYUMXFGW4B7ZM6LOGUTE", + "PX5IJZ7W2LUPKM6YN4PMZ43ZLM", + "AYK7SZ23DHC7Q56MWAJXBG76LB", + "UYCAPXJM4HNGKLIDSZ4NCEDJLN", + "UWMDZ3C2ODLACKGJPGETNQ3TA4", + "Q6OI6R3WYYJ4CCZCDJBQMCRCZR", + "LCMJHLP7354APCEGPKE7HHWTWB", + "N7T7ZKOYPAMEYTTDOWZNCN6PRD", + "UZADPU4UNHAF7L7LQDMTKA2EQH", + "DC2OEPQDECVLRVNNCS6BMH4CRA", + "37IZ427XHUMZ66EJ62U2YEZDAC", + "6BCZDQZDPZLS5OGESKNUBPSSFV", + "ST2LEMJ4OLQ32TJTLH2WCWT4WA", + "GA2TL4SFLEW4G2B5PQMIKJT5XG", + "L7PPBIET26EH7LQTLEFC4I4EIA", + "6YSM7MC2W4DEV6ULAHMX27LH56", + "QL26Z5KZ4YRRG2BXXGDRRLV357", + "677TWRAJ5NSNHCE243POQPEG7K", + "66MEBQJLGAGVXDX3KZ2YFTTVJM", + "6D4VUWAQD6R65ICSDLFAATC67V", + "7GXLD5CNU3TDUQSSW42SHL7B5D", + "RQETUMEBG2ZM2NF2EZAQHGHWWE", + "DCRX5ANWDMXZFIDVAXYLQZYMRN", + "5SDWT7YAF7L4WWANAGYINZAYXH", + "PZILRV7I2S6WKUSHKYRLA2JQY3", + "2G66TK2PZ5MOTAZDN7BFS3LAIH", + "QOLJ3WGJ6JS3FMMXBNTNAIKXVK", + "FMAL67YTHDCCYVZ5CRMN2XJPDN", + "UOTZDXTJKQ3YAIRKHTYNX6G55P", + "X3DLNPJ3V62LRHGEY4DTT35H3R", + "DKU7CHNXPB5QRZVGIQZW46XCKC", + "RAKBD4LQKEDTVDSK3DVTRWG23B", + "INTRA7BWHLVQMBRKBJNUSMF7MU", + "AUYRBNVCOYYHOHUYOOFIZ2FWMD", + "22EJVDEQ7PASLBAMTVKXOQP5RJ", + "3S6NATWA57SFTZEW7UZUOUYAEU", +} + +var benchmarkMessagesMLDSA87 = []string{ + "LQQPGPNUME6QDNDTQTS4BA7I7M", + "PTYEEJ7RMI6MXNN6PZH222Y6QI", + "R6DTHAADKNMEADDK5ECPNOTOAT", + "S2QM7VDC6UKRQNRETZMNAZ6SJT", + "EYULPTSJORQJCNYNYVHDFN4N3F", + "YETZNHZ75SXFU672VQ5WXYEPV2", + "KTSND3JGA4AN3PCMG4455JEXGR", + "JGE6HK37O6XMWZQZCHFUPNUEXP", + "CRYB2FZD2BYNANBFFO2HRZEHGZ", + "7MLNDZJ7OIEPBJZOMULOMQH2BA", + "4WQCNTIFVSX2DNALMWUKZRA6CI", + "Y5NK4OBDSDWC5WLL27CEEXYYOT", + "C4SSWSPBVCDAWJXH2CDMXR36LH", + "THDBKXRTKWJUGJMAAYTWTFMX7Z", + "NWXPUD4DAA6QOREW4AFFYQYQNG", + "3RQIJXMO7WYHBEBL3G6EOLNZNQ", + "R7JEOHFP2C7O4AVPRPRELXWOMM", + "LU6MWR7SZXVIKS54BY62X67NPA", + "FG2FFM4F2ECKHCSJ75KXK632JP", + "BF76ZDSVVUSYS5KK4FFD22YPS7", + "HCLBWZRLHEMYZLFWHLAN2BKCZ7", + "HGFVS4QC7AWXYPVRSWAK77KTQF", + "LUZ3C53PUUHBWCDJ7WAHK2UT3K", + "Y3WR6SMDUBW34N3MUT7EQYIJCV", + "F2X35AQTXVZBMPXTWNAAH4ZX2W", + "6MKFFDYWD6ZAKS3C6GRCRLZLRF", + "AFMZYYFRHKMQRNKU5UTSKQ74H6", + "TDTN7J3O367OVPWLESRNPLN4M2", + "WYMLD2X6N4CZ2RDOKF5CFTSYTG", + "UNPTSBLJ6HZRNR72T2VEEHCFX2", + "SNCM4R2P27AJOXBS67RMCARS3U", + "OU7QBE5QOXO7CIYTBJR3KOW2WK", + "2NNQOBQKZ2OD4ZAXI3SNEURYUP", + "YQTUPOYBT67XPCHIGKSGSKC3BZ", + "HGB4ZM3G76IXYWWCMVT3HONRIS", + "WZC6QUKRZZ2TOVA277JYKQITEW", + "XO2WT46A5HYL6CUJF7SGJ6YWOG", + "4QJA35PMYQIDRZ7ZHG7RLZJVGF", + "BMJZELWZ4I2UWXESU3NR6ATC4M", + "XWLFB7FN6D5PRY6YUXC5JUIBFM", + "WRAFFF27AVTIOYIBYA2IPTXI3R", + "VOXUTYTN2XZ362OJFO2R53UCUF", + "UHN73ARJ737WUJ6QYEI7U46OPO", + "3Y3K5E2A4ML3VYVNAFWEEIXTSN", + "QMU4322NKPRLE7JBGYFGS36H2S", + "NJAQTNCXPVDICTDVUKTPRCD2AX", + "OC373ZFBNV2H46T6OY3XRPSUHG", + "UBLAS6CDWE3A662MLKP7QDEOCC", + "BKFDLAL2RTPMERYVW3B7UJ5W3H", + "QFKFGXKGW5SAKLBAWQXUWW77OS", + "EJNUQHTLLOVB4ARETOGLY4WUTJ", + "N243OCMVLLAO6I2XLCYOIMQYGY", + "YRRFLWK7ZASUKYX7ZLQMW2PJ6X", + "3DGVPBWD2BIK6KQE65K72DNJNM", + "TJRYMNOAIW33VIHKLJG4GXAVUK", + "6DSRINAYXL34U54U355U7IVFGS", + "6CHA4MX7LVS77XKRWG7IYC3XVL", + "GM2CEGBEPBOHAPIOBUWJ4MJNTG", + "VJKHGBY33VUIJFEQLX3JVUNQBD", + "DTOHAD5M2KL46IZHE4TPLJWHTI", + "IYFG3UDN7ROOY2ZFSLM2BU2LMQ", + "A5OGJHPOE4PW6QSZYHZ5TKPGIC", + "FX4BCN67AEGCLUTLFPNDL3SQU5", + "MWIZQVOZOHTTBUXC3BEX62MNI5", + "BYHVJHBLK4O6LFSKEIQ3CAAKU7", + "QJU7P6KWSSKAA5GVA6RH4OV7MX", + "I3T3XM5Z5TAJHAYDQHFA2ZV7PU", + "L46MQCHV3TJ6FYIQQ2FCJXES74", + "QXZRQIYAJMXYR6PU3VDYGCIT5W", + "MFS53RR2XEYS22NYOJLGTHVTTM", + "FRWIWJRP4AQMXWX4WJ4WYVKM3E", + "X6GK6IGVLJWYSHLKHGXSW3TJDP", + "L5LPJ2HIWA4UY6G6FMZXGDEDAM", + "GD6FYOYUGDHXEQ5S2KLJEGNSN7", + "ODAL7ZRKXSPAAN5DVRBWJQCFQX", + "CV3QFBDXBPT3SCPJGUYSMDN6ZS", + "IGSLSACRZ6XID466KQIB4YNGYO", + "WZ2EACBN26RAML2S52YXRYP2OF", + "LB76VEVNOBYFMKFZ7SDFCBCHQE", + "TLFA7EU3JJFAP6EMUKNV2ZXRBM", + "SIIJF6OXAKRP25CBUYFBRCDDVP", + "TEPNI7TJ7HASJWIQMBS4VFLRQC", + "VK2JINYWEDV7IQFWH4OTAD4W5O", + "GILUH5AMVE4TM7EKPXJBZGT6EJ", + "DV7ALFRAW3TI4WMQQLDTO6RNHN", + "CAIB5G3NXC5ASPLFIWAFPVHS5B", + "MLFJXZUOAGN7EGPMXOOVTB2CL4", + "6MZYT3ANWHBOS67WGHZI3QPEAP", + "LVJDQB52C2PERSSQJRMRCJ4UBF", + "QY4VKAZAYQIZOX2L2VO2QHAQVC", + "UAA5SST2XA76JPKM3XOZ5RUHFI", + "VLZWF53JSQ6SCRUFDKVPXWAS4L", + "NX2DZIKMJIYXUNSAHFP23FHTBU", + "F5OAKDDDA34A2RPIKDPM5CYPMZ", + "E5PEP3ANIK2L4VLOST4NIYNKBD", + "IPBGFLHSMP4UFXF6XJX42T6CAL", + "XHPU7DBFTZB2TX5K34AD6DJTK3", + "2ZU7EJN2DG2UMT6HX5KGS2RFT6", + "SD5S7U34WSE4GBPKVDUDZLBIEH", + "WZFFL3BTQAV4VQMSAGCS45SGG3", + "QE7ZT2LI4CA5DLSVMHV6CP3E3V", + "YIWMS6AS72Z5N2ALZNFGCYC5QL", + "A4QJ5FNY54THAKBOB65K2JBIV7", + "6LORQGA3QO7TNADHEIINQZEE26", + "5V45M6RAKOZDMONYY4DIH3ZBL2", + "SVP7UYIZ5RTLWRKFLCWHAQV3Y2", + "C2UYQL2BBE4VLUJ3IFNFMHAN7O", + "P4DS44LGP2ERZB3OB7JISQKBXA", + "A6B4O5MWALOEHLILSVDOIXHQ4Z", + "DKQJTW5QF7KDZA3IR4X5R5F3CG", + "H6QFQX2C2QTH3YKEOO57SQS23J", + "DIF373ML2RWZMEOIVUHFXKUG7O", + "Z5PPIA3GJ74QXFFCOSUAQMN5YN", + "PM6XIDECSS5S77UXMB55VZHZSE", +} diff --git a/mldsa/mldsa.go b/mldsa/mldsa.go new file mode 100644 index 0000000..6206ec0 --- /dev/null +++ b/mldsa/mldsa.go @@ -0,0 +1,269 @@ +// Package mldsa implements the post-quantum ML-DSA signature scheme specified +// in FIPS 204. +package mldsa + +import ( + "crypto" + "crypto/subtle" + "errors" + "io" + + "github.com/go-webauthn/x/mldsa/internal/fips140/mldsa" + "github.com/go-webauthn/x/mldsa/mldsacrypto" +) + +const ( + PrivateKeySize = 32 + + MLDSA44PublicKeySize = 1312 + MLDSA65PublicKeySize = 1952 + MLDSA87PublicKeySize = 2592 + + MLDSA44SignatureSize = 2420 + MLDSA65SignatureSize = 3309 + MLDSA87SignatureSize = 4627 +) + +// Parameters represents one of the fixed parameter sets defined in FIPS 204. +// +// Most applications should use [MLDSA44]. +type Parameters struct { + name string + pubKeySize int + sigSize int +} + +var ( + mldsa44 = &Parameters{"ML-DSA-44", MLDSA44PublicKeySize, MLDSA44SignatureSize} + mldsa65 = &Parameters{"ML-DSA-65", MLDSA65PublicKeySize, MLDSA65SignatureSize} + mldsa87 = &Parameters{"ML-DSA-87", MLDSA87PublicKeySize, MLDSA87SignatureSize} +) + +// MLDSA44 returns the ML-DSA-44 parameter set defined in FIPS 204. +// +// Multiple invocations of this function will return the same value, which can +// be used for equality checks and switch statements. The returned value is safe +// for concurrent use. +func MLDSA44() *Parameters { return mldsa44 } + +// MLDSA65 returns the ML-DSA-65 parameter set defined in FIPS 204. +// +// Multiple invocations of this function will return the same value, which can +// be used for equality checks and switch statements. The returned value is safe +// for concurrent use. +func MLDSA65() *Parameters { return mldsa65 } + +// MLDSA87 returns the ML-DSA-87 parameter set defined in FIPS 204. +// +// Multiple invocations of this function will return the same value, which can +// be used for equality checks and switch statements. The returned value is safe +// for concurrent use. +func MLDSA87() *Parameters { return mldsa87 } + +// PublicKeySize returns the size of public keys for this parameter set, in bytes. +func (params *Parameters) PublicKeySize() int { return params.pubKeySize } + +// SignatureSize returns the size of signatures for this parameter set, in bytes. +func (params *Parameters) SignatureSize() int { return params.sigSize } + +// String returns the name of the parameter set, e.g. "ML-DSA-44". +func (params *Parameters) String() string { return params.name } + +// PrivateKey is an in-memory ML-DSA private key. It implements [crypto.Signer] +// and the informal extended [crypto.PrivateKey] interface. +// +// A PrivateKey is safe for concurrent use. +type PrivateKey struct { + key *mldsa.PrivateKey +} + +// GenerateKey generates a new random ML-DSA private key. +func GenerateKey(params *Parameters) (*PrivateKey, error) { + switch params { + case mldsa44: + return &PrivateKey{mldsa.GenerateKey44()}, nil + case mldsa65: + return &PrivateKey{mldsa.GenerateKey65()}, nil + case mldsa87: + return &PrivateKey{mldsa.GenerateKey87()}, nil + default: + return nil, errors.New("mldsa: invalid parameters") + } +} + +// NewPrivateKey creates a new ML-DSA private key from the given seed. +// +// The seed must be exactly [PrivateKeySize] bytes long. +func NewPrivateKey(params *Parameters, seed []byte) (*PrivateKey, error) { + var key *mldsa.PrivateKey + var err error + switch params { + case mldsa44: + key, err = mldsa.NewPrivateKey44(seed) + case mldsa65: + key, err = mldsa.NewPrivateKey65(seed) + case mldsa87: + key, err = mldsa.NewPrivateKey87(seed) + default: + return nil, errors.New("mldsa: invalid parameters") + } + if err != nil { + return nil, err + } + return &PrivateKey{key}, nil +} + +// Public returns the corresponding [PublicKey] for this private key. +// +// It implements the [crypto.Signer] interface. +func (sk *PrivateKey) Public() crypto.PublicKey { + return sk.PublicKey() +} + +// Equal reports whether sk and x are the same key (i.e. they are derived from +// the same seed). +// +// If x is not a *PrivateKey, Equal returns false. +func (sk *PrivateKey) Equal(x crypto.PrivateKey) bool { + other, ok := x.(*PrivateKey) + if !ok { + return false + } + return subtle.ConstantTimeCompare(sk.Bytes(), other.Bytes()) == 1 +} + +// PublicKey returns the corresponding [PublicKey] for this private key. +func (sk *PrivateKey) PublicKey() *PublicKey { + return &PublicKey{sk.key.PublicKey()} +} + +// Bytes returns the private key seed. +func (sk *PrivateKey) Bytes() []byte { + return sk.key.Bytes() +} + +// Sign returns a signature of the given message using this private key. +// +// If opts is nil or opts.HashFunc returns zero, the message is signed directly. +// If opts.HashFunc returns [crypto.MLDSAMu], the provided message must be a +// [pre-hashed μ message representative]. opts can be of type *[Options]. +// The io.Reader argument is ignored. +// +// [pre-hashed μ message representative]: https://www.rfc-editor.org/rfc/rfc9881.html#externalmu +func (sk *PrivateKey) Sign(_ io.Reader, message []byte, opts crypto.SignerOpts) (signature []byte, err error) { + switch { + case opts == nil || opts.HashFunc() == 0: + // Sign the message directly. + var context string + if opts, ok := opts.(*Options); ok { + context = opts.Context + } + return mldsa.Sign(sk.key, message, context) + case opts.HashFunc() == mldsacrypto.MLDSAMu: + // Sign the pre-hashed μ message representative. + return mldsa.SignExternalMu(sk.key, message) + default: + return nil, errors.New("mldsa: invalid SignerOpts.HashFunc") + } +} + +// SignDeterministic works like [PrivateKey.Sign], but the signature is +// deterministic. +func (sk *PrivateKey) SignDeterministic(message []byte, opts crypto.SignerOpts) (signature []byte, err error) { + switch { + case opts == nil || opts.HashFunc() == 0: + // Sign the message directly. + var context string + if opts, ok := opts.(*Options); ok { + context = opts.Context + } + return mldsa.SignDeterministic(sk.key, message, context) + case opts.HashFunc() == mldsacrypto.MLDSAMu: + // Sign the pre-hashed μ message representative. + return mldsa.SignExternalMuDeterministic(sk.key, message) + default: + return nil, errors.New("mldsa: invalid SignerOpts.HashFunc") + } +} + +// PublicKey is an ML-DSA public key. It implements the informal extended +// [crypto.PublicKey] interface. +// +// A PublicKey is safe for concurrent use. +type PublicKey struct { + key *mldsa.PublicKey +} + +// NewPublicKey creates a new ML-DSA public key from the given encoding. +func NewPublicKey(params *Parameters, seed []byte) (*PublicKey, error) { + var key *mldsa.PublicKey + var err error + switch params { + case mldsa44: + key, err = mldsa.NewPublicKey44(seed) + case mldsa65: + key, err = mldsa.NewPublicKey65(seed) + case mldsa87: + key, err = mldsa.NewPublicKey87(seed) + default: + return nil, errors.New("mldsa: invalid parameters") + } + if err != nil { + return nil, err + } + return &PublicKey{key}, nil +} + +// Bytes returns the public key encoding. +func (pk *PublicKey) Bytes() []byte { + return pk.key.Bytes() +} + +// Equal reports whether pk and x are the same key (i.e. they have the same +// encoding). +// +// If x is not a *PublicKey, Equal returns false. +func (pk *PublicKey) Equal(x crypto.PublicKey) bool { + other, ok := x.(*PublicKey) + if !ok { + return false + } + return subtle.ConstantTimeCompare(pk.Bytes(), other.Bytes()) == 1 +} + +// Parameters returns the parameters associated with this public key. +func (pk *PublicKey) Parameters() *Parameters { + switch pk.key.Parameters() { + case "ML-DSA-44": + return mldsa44 + case "ML-DSA-65": + return mldsa65 + case "ML-DSA-87": + return mldsa87 + default: + panic("mldsa: internal error: invalid parameters") + } +} + +// Verify reports whether signature is a valid signature of message by pk. +func Verify(pk *PublicKey, message []byte, signature []byte, opts *Options) error { + var context string + if opts != nil { + context = opts.Context + } + return mldsa.Verify(pk.key, message, signature, context) +} + +// Options contains additional options for signing and verifying ML-DSA signatures. +type Options struct { + // Context can be used to distinguish signatures created for different + // purposes. It must be at most 255 bytes long, and it is empty by default. + // + // The same context must be used when signing and verifying a signature. + Context string +} + +// HashFunc returns zero, to implement the [crypto.SignerOpts] interface. +func (opts *Options) HashFunc() crypto.Hash { + return 0 +} diff --git a/mldsa/mldsa_test.go b/mldsa/mldsa_test.go new file mode 100644 index 0000000..754fd13 --- /dev/null +++ b/mldsa/mldsa_test.go @@ -0,0 +1,558 @@ +package mldsa_test + +import ( + "bytes" + "crypto" + "crypto/rand" + "crypto/sha3" + "fmt" + "log" + "strings" + "testing" + + "github.com/go-webauthn/x/mldsa" + "github.com/go-webauthn/x/mldsa/mldsacrypto" +) + +func TestParameters(t *testing.T) { + tests := []struct { + params *mldsa.Parameters + name string + pubKeySize int + sigSize int + }{ + {mldsa.MLDSA44(), "ML-DSA-44", mldsa.MLDSA44PublicKeySize, mldsa.MLDSA44SignatureSize}, + {mldsa.MLDSA65(), "ML-DSA-65", mldsa.MLDSA65PublicKeySize, mldsa.MLDSA65SignatureSize}, + {mldsa.MLDSA87(), "ML-DSA-87", mldsa.MLDSA87PublicKeySize, mldsa.MLDSA87SignatureSize}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.params.String() != tt.name { + t.Errorf("String() = %q, want %q", tt.params.String(), tt.name) + } + if tt.params.PublicKeySize() != tt.pubKeySize { + t.Errorf("PublicKeySize() = %d, want %d", tt.params.PublicKeySize(), tt.pubKeySize) + } + if tt.params.SignatureSize() != tt.sigSize { + t.Errorf("SignatureSize() = %d, want %d", tt.params.SignatureSize(), tt.sigSize) + } + }) + } +} + +func TestParametersIdentity(t *testing.T) { + // Multiple invocations return the same value. + p44a, p44b := mldsa.MLDSA44(), mldsa.MLDSA44() + if p44a != p44b { + t.Error("MLDSA44() returned different values") + } + p65a, p65b := mldsa.MLDSA65(), mldsa.MLDSA65() + if p65a != p65b { + t.Error("MLDSA65() returned different values") + } + p87a, p87b := mldsa.MLDSA87(), mldsa.MLDSA87() + if p87a != p87b { + t.Error("MLDSA87() returned different values") + } + // Different parameter sets are not equal. + if p44a == p65a { + t.Error("MLDSA44() == MLDSA65()") + } +} + +func testAllParams(t *testing.T, f func(t *testing.T, params *mldsa.Parameters)) { + t.Run("ML-DSA-44", func(t *testing.T) { f(t, mldsa.MLDSA44()) }) + t.Run("ML-DSA-65", func(t *testing.T) { f(t, mldsa.MLDSA65()) }) + t.Run("ML-DSA-87", func(t *testing.T) { f(t, mldsa.MLDSA87()) }) +} + +func TestGenerateKey(t *testing.T) { + testAllParams(t, func(t *testing.T, params *mldsa.Parameters) { + sk, err := mldsa.GenerateKey(params) + if err != nil { + t.Fatal(err) + } + if len(sk.Bytes()) != mldsa.PrivateKeySize { + t.Errorf("seed length = %d, want %d", len(sk.Bytes()), mldsa.PrivateKeySize) + } + pk := sk.PublicKey() + if len(pk.Bytes()) != params.PublicKeySize() { + t.Errorf("public key length = %d, want %d", len(pk.Bytes()), params.PublicKeySize()) + } + if pk.Parameters() != params { + t.Errorf("Parameters() = %v, want %v", pk.Parameters(), params) + } + }) +} + +func TestGenerateKeyInvalidParams(t *testing.T) { + _, err := mldsa.GenerateKey(&mldsa.Parameters{}) + if err == nil { + t.Fatal("expected error for invalid parameters") + } +} + +func TestNewPrivateKey(t *testing.T) { + testAllParams(t, func(t *testing.T, params *mldsa.Parameters) { + seed := make([]byte, mldsa.PrivateKeySize) + rand.Read(seed) + sk, err := mldsa.NewPrivateKey(params, seed) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(sk.Bytes(), seed) { + t.Error("Bytes() does not match seed") + } + }) +} + +func TestNewPrivateKeyInvalidSeed(t *testing.T) { + testAllParams(t, func(t *testing.T, params *mldsa.Parameters) { + // Too short. + if _, err := mldsa.NewPrivateKey(params, make([]byte, 31)); err == nil { + t.Error("expected error for short seed") + } + // Too long. + if _, err := mldsa.NewPrivateKey(params, make([]byte, 33)); err == nil { + t.Error("expected error for long seed") + } + }) +} + +func TestNewPrivateKeyInvalidParams(t *testing.T) { + _, err := mldsa.NewPrivateKey(&mldsa.Parameters{}, make([]byte, mldsa.PrivateKeySize)) + if err == nil { + t.Fatal("expected error for invalid parameters") + } +} + +func TestKeyRoundTrip(t *testing.T) { + testAllParams(t, func(t *testing.T, params *mldsa.Parameters) { + sk1, err := mldsa.GenerateKey(params) + if err != nil { + t.Fatal(err) + } + sk2, err := mldsa.NewPrivateKey(params, sk1.Bytes()) + if err != nil { + t.Fatal(err) + } + if !sk1.Equal(sk2) { + t.Error("round-tripped private key is not equal") + } + if !sk1.PublicKey().Equal(sk2.PublicKey()) { + t.Error("public key from round-tripped private key is not equal") + } + + pk1 := sk1.PublicKey() + pk2, err := mldsa.NewPublicKey(params, pk1.Bytes()) + if err != nil { + t.Fatal(err) + } + if !pk1.Equal(pk2) { + t.Error("round-tripped public key is not equal") + } + if pk2.Parameters() != params { + t.Errorf("Parameters() = %v, want %v", pk2.Parameters(), params) + } + }) +} + +func TestNewPublicKeyInvalidEncoding(t *testing.T) { + testAllParams(t, func(t *testing.T, params *mldsa.Parameters) { + // Wrong length. + if _, err := mldsa.NewPublicKey(params, make([]byte, 10)); err == nil { + t.Error("expected error for wrong length") + } + }) +} + +func TestNewPublicKeyInvalidParams(t *testing.T) { + _, err := mldsa.NewPublicKey(&mldsa.Parameters{}, make([]byte, 100)) + if err == nil { + t.Fatal("expected error for invalid parameters") + } +} + +func TestPrivateKeyEqual(t *testing.T) { + sk1, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + sk2, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + + if !sk1.Equal(sk1) { + t.Error("key should be equal to itself") + } + if sk1.Equal(sk2) { + t.Error("different keys should not be equal") + } + if sk1.Equal("not a key") { + t.Error("should not be equal to non-key type") + } +} + +func TestPublicKeyEqual(t *testing.T) { + sk1, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + sk2, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + + pk1 := sk1.PublicKey() + pk2 := sk2.PublicKey() + + if !pk1.Equal(pk1) { + t.Error("key should be equal to itself") + } + if pk1.Equal(pk2) { + t.Error("different keys should not be equal") + } + if pk1.Equal("not a key") { + t.Error("should not be equal to non-key type") + } +} + +func TestPublicMethod(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + pub := sk.Public() + pk, ok := pub.(*mldsa.PublicKey) + if !ok { + t.Fatalf("Public() returned %T, want *mldsa.PublicKey", pub) + } + if !pk.Equal(sk.PublicKey()) { + t.Error("Public() and PublicKey() returned different keys") + } +} + +func TestSignAndVerify(t *testing.T) { + testAllParams(t, func(t *testing.T, params *mldsa.Parameters) { + sk, err := mldsa.GenerateKey(params) + if err != nil { + t.Fatal(err) + } + msg := []byte("test message") + sig, err := sk.Sign(nil, msg, nil) + if err != nil { + t.Fatal(err) + } + if len(sig) != params.SignatureSize() { + t.Errorf("signature length = %d, want %d", len(sig), params.SignatureSize()) + } + if err := mldsa.Verify(sk.PublicKey(), msg, sig, nil); err != nil { + t.Errorf("Verify failed: %v", err) + } + }) +} + +func TestSignAndVerifyWithContext(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + msg := []byte("test message") + + sig, err := sk.Sign(nil, msg, &mldsa.Options{Context: "test context"}) + if err != nil { + t.Fatal(err) + } + + // Verify with correct context. + if err := mldsa.Verify(sk.PublicKey(), msg, sig, &mldsa.Options{Context: "test context"}); err != nil { + t.Errorf("Verify failed: %v", err) + } + // Verify with wrong context. + if err := mldsa.Verify(sk.PublicKey(), msg, sig, nil); err == nil { + t.Error("expected verification failure with wrong context") + } + if err := mldsa.Verify(sk.PublicKey(), msg, sig, &mldsa.Options{Context: "wrong"}); err == nil { + t.Error("expected verification failure with wrong context") + } +} + +func TestSignContextTooLong(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + longCtx := strings.Repeat("x", 256) + _, err := sk.Sign(nil, []byte("msg"), &mldsa.Options{Context: longCtx}) + if err == nil { + t.Fatal("expected error for context too long") + } +} + +func TestSignDeterministic(t *testing.T) { + testAllParams(t, func(t *testing.T, params *mldsa.Parameters) { + sk, _ := mldsa.GenerateKey(params) + msg := []byte("test message") + + sig1, err := sk.SignDeterministic(msg, nil) + if err != nil { + t.Fatal(err) + } + sig2, err := sk.SignDeterministic(msg, nil) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(sig1, sig2) { + t.Error("deterministic signatures differ") + } + if err := mldsa.Verify(sk.PublicKey(), msg, sig1, nil); err != nil { + t.Errorf("Verify failed: %v", err) + } + }) +} + +func TestSignDeterministicWithContext(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + msg := []byte("test message") + + sig1, err := sk.SignDeterministic(msg, &mldsa.Options{Context: "ctx"}) + if err != nil { + t.Fatal(err) + } + sig2, err := sk.SignDeterministic(msg, &mldsa.Options{Context: "ctx"}) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(sig1, sig2) { + t.Error("deterministic signatures with same context differ") + } + // Different context should produce different signature. + sig3, err := sk.SignDeterministic(msg, nil) + if err != nil { + t.Fatal(err) + } + if bytes.Equal(sig1, sig3) { + t.Error("deterministic signatures with different context should differ") + } +} + +func TestSignDeterministicContextTooLong(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + longCtx := strings.Repeat("x", 256) + _, err := sk.SignDeterministic([]byte("msg"), &mldsa.Options{Context: longCtx}) + if err == nil { + t.Fatal("expected error for context too long") + } +} + +func TestSignInvalidHashFunc(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + _, err := sk.Sign(nil, []byte("msg"), crypto.SHA256) + if err == nil { + t.Fatal("expected error for invalid HashFunc") + } +} + +func TestSignDeterministicInvalidHashFunc(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + _, err := sk.SignDeterministic([]byte("msg"), crypto.SHA256) + if err == nil { + t.Fatal("expected error for invalid HashFunc") + } +} + +// computeMu computes the μ message representative as specified in FIPS 204. +// μ = SHAKE256(tr || 0x00 || len(ctx) || ctx || msg), where +// tr = SHAKE256(publicKeyBytes) is 64 bytes. +func computeMu(pk *mldsa.PublicKey, msg []byte, context string) []byte { + H := sha3.NewSHAKE256() + H.Write(pk.Bytes()) + var tr [64]byte + H.Read(tr[:]) + + H.Reset() + H.Write(tr[:]) + H.Write([]byte{0x00}) // ML-DSA domain separator + H.Write([]byte{byte(len(context))}) + H.Write([]byte(context)) + H.Write(msg) + mu := make([]byte, 64) + H.Read(mu) + return mu +} + +func TestSignExternalMu(t *testing.T) { + testAllParams(t, func(t *testing.T, params *mldsa.Parameters) { + sk, _ := mldsa.GenerateKey(params) + msg := []byte("test message") + + mu := computeMu(sk.PublicKey(), msg, "") + + sig, err := sk.Sign(nil, mu, mldsacrypto.MLDSAMu) + if err != nil { + t.Fatal(err) + } + // The signature produced via external mu should verify against + // the original message via the standard Verify. + if err := mldsa.Verify(sk.PublicKey(), msg, sig, nil); err != nil { + t.Errorf("Verify failed: %v", err) + } + }) +} + +func TestSignExternalMuWithContext(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + msg := []byte("test message") + + mu := computeMu(sk.PublicKey(), msg, "my context") + + sig, err := sk.Sign(nil, mu, mldsacrypto.MLDSAMu) + if err != nil { + t.Fatal(err) + } + if err := mldsa.Verify(sk.PublicKey(), msg, sig, &mldsa.Options{Context: "my context"}); err != nil { + t.Errorf("Verify failed: %v", err) + } + // Should fail with wrong context. + if err := mldsa.Verify(sk.PublicKey(), msg, sig, nil); err == nil { + t.Error("expected verification failure with wrong context") + } +} + +func TestSignExternalMuDeterministic(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + msg := []byte("test message") + mu := computeMu(sk.PublicKey(), msg, "") + + sig1, err := sk.SignDeterministic(mu, mldsacrypto.MLDSAMu) + if err != nil { + t.Fatal(err) + } + sig2, err := sk.SignDeterministic(mu, mldsacrypto.MLDSAMu) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(sig1, sig2) { + t.Error("deterministic external mu signatures differ") + } + if err := mldsa.Verify(sk.PublicKey(), msg, sig1, nil); err != nil { + t.Errorf("Verify failed: %v", err) + } +} + +func TestSignExternalMuInvalidLength(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + // μ must be exactly 64 bytes. + _, err := sk.Sign(nil, make([]byte, 32), mldsacrypto.MLDSAMu) + if err == nil { + t.Fatal("expected error for invalid mu length") + } + _, err = sk.SignDeterministic(make([]byte, 32), mldsacrypto.MLDSAMu) + if err == nil { + t.Fatal("expected error for invalid mu length") + } +} + +func TestVerifyWrongKey(t *testing.T) { + sk1, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + sk2, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + msg := []byte("test message") + sig, _ := sk1.Sign(nil, msg, nil) + + if err := mldsa.Verify(sk2.PublicKey(), msg, sig, nil); err == nil { + t.Error("expected verification failure with wrong key") + } +} + +func TestVerifyWrongMessage(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + sig, _ := sk.Sign(nil, []byte("message 1"), nil) + + if err := mldsa.Verify(sk.PublicKey(), []byte("message 2"), sig, nil); err == nil { + t.Error("expected verification failure with wrong message") + } +} + +func TestVerifyTruncatedSignature(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + msg := []byte("test message") + sig, _ := sk.Sign(nil, msg, nil) + + if err := mldsa.Verify(sk.PublicKey(), msg, sig[:len(sig)-1], nil); err == nil { + t.Error("expected verification failure with truncated signature") + } +} + +func TestOptionsHashFunc(t *testing.T) { + opts := &mldsa.Options{} + if opts.HashFunc() != 0 { + t.Errorf("HashFunc() = %d, want 0", opts.HashFunc()) + } +} + +func TestCryptoSignerInterface(t *testing.T) { + sk, _ := mldsa.GenerateKey(mldsa.MLDSA44()) + var _ crypto.Signer = sk +} + +func ExamplePrivateKey_Sign_withContext() { + sk, err := mldsa.GenerateKey(mldsa.MLDSA44()) + if err != nil { + log.Fatal(err) + } + + message := []byte("hello, world") + sig, err := sk.Sign(nil, message, &mldsa.Options{Context: "example"}) + if err != nil { + log.Fatal(err) + } + + if err := mldsa.Verify(sk.PublicKey(), message, sig, &mldsa.Options{Context: "example"}); err != nil { + log.Fatal(err) + } + fmt.Println("signature verified") + // Output: signature verified +} + +func ExamplePrivateKey_SignDeterministic() { + sk, err := mldsa.GenerateKey(mldsa.MLDSA44()) + if err != nil { + log.Fatal(err) + } + + message := []byte("hello, world") + sig, err := sk.SignDeterministic(message, nil) + if err != nil { + log.Fatal(err) + } + + if err := mldsa.Verify(sk.PublicKey(), message, sig, nil); err != nil { + log.Fatal(err) + } + fmt.Println("signature verified") + // Output: signature verified +} + +func ExamplePrivateKey_Sign_externalMu() { + sk, err := mldsa.GenerateKey(mldsa.MLDSA44()) + if err != nil { + log.Fatal(err) + } + pk := sk.PublicKey() + + // Compute μ externally, as specified in FIPS 204. + // + // This is useful when the message is large, because μ can be computed + // incrementally by the caller without buffering the full message. + // + // First, compute tr = SHAKE256(publicKey). + H := sha3.NewSHAKE256() + H.Write(pk.Bytes()) + var tr [64]byte + H.Read(tr[:]) + + // Then, compute μ = SHAKE256(tr || 0x00 || len(ctx) || ctx || msg). + // The second byte is 0x00 for ML-DSA (as opposed to HashML-DSA) and ctx + // is the context string, empty by default. + message := []byte("hello, world") + H.Reset() + H.Write(tr[:]) + H.Write([]byte{0x00}) // ML-DSA domain separator + H.Write([]byte{0x00}) // context length (0 for empty context) + H.Write(message) + mu := make([]byte, 64) + H.Read(mu) + + // Sign the pre-computed μ by passing MLDSAMu as the hash function. + sig, err := sk.Sign(nil, mu, mldsacrypto.MLDSAMu) + if err != nil { + log.Fatal(err) + } + + // Verify against the original message using the standard Verify function. + if err := mldsa.Verify(pk, message, sig, nil); err != nil { + log.Fatal(err) + } + fmt.Println("signature verified") + // Output: signature verified +} diff --git a/mldsa/mldsacrypto/mldsamu.go b/mldsa/mldsacrypto/mldsamu.go new file mode 100644 index 0000000..7ce3083 --- /dev/null +++ b/mldsa/mldsacrypto/mldsamu.go @@ -0,0 +1,13 @@ +// Package mldsacrypto is a stand-in for the standard library's crypto package, +// until MLDSAMu is added there, at which point this package will become a +// wrapper. +package mldsacrypto + +import "crypto" + +// MLDSAMu is a function that produces a [pre-hashed μ message representative]. +// It has no implementation, but is used a [crypto.SignerOpts.HashFunc] return +// value for [mldsa.PrivateKey.Sign]. +// +// [pre-hashed μ message representative]: https://www.rfc-editor.org/rfc/rfc9881.html#externalmu +const MLDSAMu crypto.Hash = 0xABCDEF12