Skip to content

Commit 4b23092

Browse files
authored
feat: Shamir's secret sharing (#49)
1 parent 9691cbb commit 4b23092

5 files changed

Lines changed: 144 additions & 89 deletions

File tree

pactus/crypto/sss/__init__.py

Whitespace-only changes.

pactus/crypto/sss/sss.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""
2+
The following Python implementation of Shamir's secret sharing is
3+
released into the Public Domain under the terms of CC0 and OWFa:
4+
https://creativecommons.org/publicdomain/zero/1.0/
5+
http://www.openwebfoundation.org/legal/the-owf-1-0-agreements/owfa-1-0.
6+
7+
See the bottom few lines for usage. Tested on Python 2 and 3.
8+
"""
9+
from __future__ import annotations
10+
11+
import functools
12+
import random
13+
14+
_RINT = functools.partial(random.SystemRandom().randint, 0)
15+
16+
17+
def _eval_at(poly: list[int], x: int, prime: int) -> int:
18+
"""
19+
Evaluate polynomial (coefficient tuple) at x, used to generate a
20+
shamir pool in make_random_shares below.
21+
"""
22+
accum = 0
23+
for coeff in reversed(poly):
24+
accum *= x
25+
accum += coeff
26+
accum %= prime
27+
28+
return accum
29+
30+
31+
def _extended_gcd(a: int, b: int) -> int:
32+
"""
33+
Division in integers modulus p means finding the inverse of the
34+
denominator modulo p and then multiplying the numerator by this
35+
inverse (Note: inverse of A is B such that A*B % p == 1). This can
36+
be computed via the extended Euclidean algorithm
37+
http://en.wikipedia.org/wiki/Modular_multiplicative_inverse#Computation.
38+
"""
39+
x = 0
40+
last_x = 1
41+
y = 1
42+
last_y = 0
43+
while b != 0:
44+
quot = a // b
45+
a, b = b, a % b
46+
x, last_x = last_x - quot * x, x
47+
y, last_y = last_y - quot * y, y
48+
49+
return last_x, last_y
50+
51+
52+
def _divmod(num: int, den: int, p: int) -> int:
53+
"""
54+
Compute num / den modulo prime p.
55+
56+
To explain this, the result will be such that:
57+
den * _divmod(num, den, p) % p == num
58+
"""
59+
inv, _ = _extended_gcd(den, p)
60+
61+
return num * inv
62+
63+
64+
def _lagrange_interpolate(x: int, x_s: list[int], y_s: list[int], p: int) -> int:
65+
"""
66+
Find the y-value for the given x, given n (x, y) points;
67+
k points will define a polynomial of up to kth order.
68+
"""
69+
k = len(x_s)
70+
if k != len(set(x_s)):
71+
msg = "points must be distinct"
72+
raise ValueError(msg)
73+
74+
def _pi(vals: list[int]) -> int: # upper-case PI -- product of inputs
75+
accum = 1
76+
for v in vals:
77+
accum *= v
78+
return accum
79+
80+
nums = [] # avoid inexact division
81+
dens = []
82+
for i in range(k):
83+
others = list(x_s)
84+
cur = others.pop(i)
85+
nums.append(_pi(x - o for o in others))
86+
dens.append(_pi(cur - o for o in others))
87+
88+
den = _pi(dens)
89+
num = sum([_divmod(nums[i] * den * y_s[i] % p, dens[i], p) for i in range(k)])
90+
91+
return (_divmod(num, den, p) + p) % p
92+
93+
94+
def make_random_shares(secret: int, minimum: int, shares: int, prime: int) -> list[tuple[int, int]]:
95+
"""Generate a random shamir pool for a given secret, returns share points."""
96+
if minimum > shares:
97+
msg = "Pool secret would be irrecoverable."
98+
raise ValueError(msg)
99+
poly = [secret] + [_RINT(prime - 1) for i in range(minimum - 1)]
100+
return [(i, _eval_at(poly, i, prime)) for i in range(1, shares + 1)]
101+
102+
103+
104+
def recover_secret(shares: list[tuple[int, int]], prime: int) -> int:
105+
"""
106+
Recover the secret from share points
107+
(points (x,y) on the polynomial).
108+
"""
109+
if len(shares) < 2:
110+
msg = "need at least two shares"
111+
raise ValueError(msg)
112+
113+
x_s, y_s = zip(*shares)
114+
115+
return _lagrange_interpolate(0, x_s, y_s, prime)

pactus/utils/utils.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,3 @@ def encode_from_base256_with_type(hrp: str, typ: str, data: bytes) -> str:
1717
converted = bech32m.convertbits(list(data), 8, 5, pad=True)
1818
converted = [typ, *converted]
1919
return bech32m.bech32_encode(hrp, converted, bech32m.Encoding.BECH32M)
20-
21-
22-
def evaluate_polynomial(c: list[int], x: int, mod: int) -> int | None:
23-
"""
24-
Evaluate the polynomial f(x) = c[0] + c[1] * x + c[2] * x^2 + ... + c[n-1] * x^(n-1).
25-
26-
Args:
27-
c: List of polynomial coefficients (c[0] is the constant term)
28-
x: The value at which to evaluate the polynomial
29-
mod: The modulus to use for the evaluation
30-
31-
Returns:
32-
The computed value f(x) if success, None otherwise
33-
34-
"""
35-
if not c:
36-
return None
37-
38-
if len(c) == 1:
39-
return c[0]
40-
41-
y = c[-1]
42-
for i in range(len(c) - 2, -1, -1):
43-
y = (y * x + c[i]) % mod
44-
45-
return y

tests/test_crypto_sss.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import unittest
2+
from pactus.crypto.sss import sss
3+
4+
5+
class TestEvaluatePolynomial(unittest.TestCase):
6+
def test_wikipedia_example(self):
7+
# https://en.wikipedia.org/wiki/Shamir%27s_secret_sharing
8+
self.assertEqual(sss._eval_at([1234, 166, 94], 1, 2**127 - 1), 1494)
9+
self.assertEqual(sss._eval_at([1234, 166, 94], 2, 2**127 - 1), 1942)
10+
self.assertEqual(sss._eval_at([1234, 166, 94], 3, 2**127 - 1), 2578)
11+
self.assertEqual(sss._eval_at([1234, 166, 94], 4, 2**127 - 1), 3402)
12+
self.assertEqual(sss._eval_at([1234, 166, 94], 5, 2**127 - 1), 4414)
13+
self.assertEqual(sss._eval_at([1234, 166, 94], 6, 2**127 - 1), 5614)
14+
15+
16+
class TestRecover(unittest.TestCase):
17+
def test_recover_secret_1(self):
18+
shares = [(1, 1494), (2, 1942), (3, 2578)]
19+
prime = 2**127 - 1
20+
self.assertEqual(sss.recover_secret(shares, prime), 1234)
21+
22+
def test_recover_secret_2(self):
23+
shares = [(1, 1494), (3, 2578), (6, 5614)]
24+
prime = 2**127 - 1
25+
self.assertEqual(sss.recover_secret(shares, prime), 1234)
26+
27+
28+
if __name__ == "__main__":
29+
unittest.main()

tests/test_utils.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

0 commit comments

Comments
 (0)