Last active
March 13, 2019 18:19
-
-
Save Goheeca/ae59f472e41e705ba72113a28760831b to your computer and use it in GitHub Desktop.
Shamir's Secret Sharing Scheme
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fractions import Fraction | |
import random | |
from itertools import zip_longest | |
class Field(object):
    """Interface that every concrete field used with FieldElem implements.

    A field knows how to build elements (``field(value)``, ``zero``, ``one``,
    ``rand``), how to combine *raw* element values (``add``, ``neg``, ``mul``,
    ``inv``) and how to turn a raw value back into a plain Python value
    (``value``). Every hook in this base class raises NotImplementedError.
    """

    def __init__(self):
        # Concrete finite fields overwrite this with a number; fields that
        # keep the base initializer (e.g. the rationals) leave it as None.
        self.characteristic = None

    # ----- element construction / conversion --------------------------------

    def __call__(self, value):
        """Wrap *value* as an element of this field."""
        raise NotImplementedError

    def zero(self):
        """Return the additive identity as an element."""
        raise NotImplementedError

    def one(self):
        """Return the multiplicative identity as an element."""
        raise NotImplementedError

    def rand(self):
        """Return a random element."""
        raise NotImplementedError

    def value(self, val):
        """Convert raw value *val* into a plain Python value."""
        raise NotImplementedError

    # ----- arithmetic on raw values -----------------------------------------

    def add(self, a, b):
        """Return the raw sum of raw values *a* and *b*."""
        raise NotImplementedError

    def neg(self, a):
        """Return the raw additive inverse of *a*."""
        raise NotImplementedError

    def mul(self, a, b):
        """Return the raw product of raw values *a* and *b*."""
        raise NotImplementedError

    def inv(self, a):
        """Return the raw multiplicative inverse of *a*."""
        raise NotImplementedError
class FieldElem(object):
    """An element of a Field: a raw ``value`` paired with the field interpreting it.

    All arithmetic is delegated to the field's ``add``/``neg``/``mul``/``inv``
    methods on raw values. Operands that are not FieldElems are first converted
    with ``field(other)``, so mixing with plain numbers works (``elem + 1``,
    ``2 * elem``, ``1 / elem``).

    NOTE: the augmented operators (``+=``, ``-=``, ``*=``, ``/=``) replace
    ``self.value`` in place, and ``__hash__`` is derived from ``self.value``;
    do not mutate an element that is stored in a set or used as a dict key.
    """

    def __init__(self, field, value):
        self._field = field
        self.value = value

    def __hash__(self):
        return self.value.__hash__()

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return self._field.__repr__() + "<" + self.__str__() + ">"

    def _wrap(self, value):
        """Return a new element of this element's field holding raw *value*."""
        return FieldElem(self._field, value)

    def _check(self, other, fn=None):
        """Return ``(self.value, raw_other)`` ready for a binary field operation.

        *other* is either a FieldElem (which must belong to the same field) or
        anything the field can convert via ``field(other)``. If *fn* is given
        (e.g. ``field.neg`` for subtraction, ``field.inv`` for division) it is
        applied to the raw other value.

        Raises:
            ValueError: if *other* is a FieldElem of a different field.
        """
        if isinstance(other, self.__class__):
            # Explicit check rather than `assert`, which `python -O` strips.
            if other._field != self._field:
                raise ValueError("Elements are from different fields.")
            other = other.value
        else:
            other = self._field.__call__(other).value
        return self.value, (other if fn is None else fn(other))

    def __add__(self, other):
        return self._wrap(self._field.add(*self._check(other)))

    def __neg__(self):
        return self._wrap(self._field.neg(self.value))

    def __sub__(self, other):
        # a - b computed as a + neg(b).
        return self._wrap(self._field.add(*self._check(other, self._field.neg)))

    def __mul__(self, other):
        return self._wrap(self._field.mul(*self._check(other)))

    def __invert__(self):
        # ~a is the multiplicative inverse of a.
        return self._wrap(self._field.inv(self.value))

    def __truediv__(self, other):
        # a / b computed as a * inv(b).
        return self._wrap(self._field.mul(*self._check(other, self._field.inv)))

    def __iadd__(self, other):
        self.value = self._field.add(*self._check(other))
        return self

    def __isub__(self, other):
        self.value = self._field.add(*self._check(other, self._field.neg))
        return self

    def __imul__(self, other):
        self.value = self._field.mul(*self._check(other))
        return self

    def __itruediv__(self, other):
        # Python 3 looks up __itruediv__ for `/=`; the old __idiv__ name alone
        # was never used, so `/=` silently fell back to __truediv__.
        self.value = self._field.mul(*self._check(other, self._field.inv))
        return self

    # Python 2 name, kept so existing explicit callers keep working.
    __idiv__ = __itruediv__

    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other):
        # other - self == neg(self) + other
        s, o = self._check(other)
        return self._wrap(self._field.add(self._field.neg(s), o))

    def __rtruediv__(self, other):
        # other / self == inv(self) * other
        s, o = self._check(other)
        return self._wrap(self._field.mul(self._field.inv(s), o))

    def __eq__(self, other):
        s, o = self._check(other)
        return s.__eq__(o)

    def __ne__(self, other):
        s, o = self._check(other)
        return s.__ne__(o)

    def __call__(self):
        """Return this element as a plain Python value (via ``field.value``)."""
        return self._field.value(self.value)
class RationalField(Field):
    """The field Q of rational numbers; raw element values are ``Fraction``s."""

    def __repr__(self):
        return "Q"

    def __call__(self, value):
        return FieldElem(self, Fraction(value))

    # Additive structure.

    def zero(self):
        return FieldElem(self, Fraction(0))

    def add(self, a, b):
        return a + b

    def neg(self, a):
        return -a

    # Multiplicative structure.

    def one(self):
        return FieldElem(self, Fraction(1))

    def mul(self, a, b):
        return a * b

    def inv(self, a):
        return 1 / a

    # Sampling and conversion.

    def rand(self):
        # A standard-normal float, converted exactly to a Fraction by __call__.
        return self(random.gauss(0, 1))

    def value(self, val):
        # Truncates toward zero, as int() does for Fractions.
        return int(val)


# Shared instance; the sharing functions use it as their default field.
rationalField = RationalField()
def xgcd(a, b, zero=lambda: 0, one=lambda: 1, mod=lambda x: x):
    """Extended Euclidean algorithm.

    Return ``(g, x, y)`` such that ``a*x + b*y == g == gcd(a, b)``.

    The ``zero``/``one`` factories and the ``mod`` normalizer let the same
    code run on ints (the defaults) and on polynomial values, where ``mod``
    reduces intermediate coefficients.
    """
    x0, x1, y0, y1 = zero(), one(), one(), zero()
    z = zero()
    # Invariant: b == x0*A + y0*B and a == x1*A + y1*B for the original A, B.
    while a != z:
        q, b, a = mod(b // a), a, mod(b % a)
        y0, y1 = y1, mod(y0 - q * y1)
        x0, x1 = x1, mod(x0 - q * x1)
    return b, x0, y0


def mulinv(a, b, zero=lambda: 0, one=lambda: 1, mod=lambda x: x):
    """Return x such that ``(x * a) % b == 1``.

    Arguments other than *a* and *b* are passed through to xgcd.

    Raises:
        ZeroDivisionError: if *a* has no inverse modulo *b* (gcd(a, b) is not
            one). This replaces the former silent ``None`` return value, and
            it matches what ``1 / Fraction(0)`` raises.
    """
    g, x, _ = xgcd(a, b, zero, one, mod)
    if g == one():
        return x % b
    raise ZeroDivisionError(f"{a!r} has no inverse modulo {b!r}")
class GaloisPrimeField(Field):
    """The prime field GF(p); raw element values are ints in ``range(p)``."""

    # OS-entropy RNG: these elements become secret-sharing coefficients, and
    # the default Mersenne Twister in `random` is predictable from its output.
    _rng = random.SystemRandom()

    def __init__(self, prime):
        self._prime = prime
        self.characteristic = prime

    def __call__(self, value):
        # int() truncates non-integers before reducing modulo p.
        return FieldElem(self, int(value) % self._prime)

    def zero(self):
        return FieldElem(self, 0)

    def add(self, a, b):
        return (a + b) % self._prime

    def neg(self, a):
        return -a % self._prime

    def one(self):
        return FieldElem(self, 1)

    def mul(self, a, b):
        return (a * b) % self._prime

    def inv(self, a):
        """Return the inverse of raw value *a* modulo p (delegates to mulinv)."""
        return mulinv(a, self._prime)

    def rand(self):
        """Return a uniformly random element.

        randrange(p) draws from 0..p-1. The previous randint(0, p) included p,
        which reduces to 0, so 0 came up twice as often as any other value.
        """
        return self.__call__(self._rng.randrange(self._prime))

    def value(self, val):
        return val

    def __repr__(self):
        return f"GF({self._prime})"
class Polynom(list):
    """Dense polynomial: a list of coefficients, lowest degree first.

    ``Polynom([c0, c1, c2])`` is ``c0 + c1*x + c2*x**2``. Trailing zero
    coefficients are allowed; equality, hashing and ``rank()`` ignore them.
    Coefficients use their own Python arithmetic: nothing is reduced modulo a
    prime unless ``elem_mod()`` is called explicitly.
    """

    def __init__(self, list_):
        # Lists (including Polynoms) are copied; any other value becomes a
        # constant polynomial.
        if isinstance(list_, list):
            super().__init__(list_)
        else:
            super().__init__([list_])

    def __hash__(self):
        # Must agree with __eq__, which ignores trailing zeros, so hash the
        # coefficient tuple with those stripped. The tuple keeps both the
        # leading coefficient and coefficient positions, so 1, x, x**2, ...
        # do not all collide.
        coeffs = list(self)
        while coeffs and coeffs[-1] == 0:
            coeffs.pop()
        return hash(tuple(coeffs))

    def elem_mod(self, m):
        """Return a copy with every coefficient reduced modulo *m*."""
        return Polynom([a % m for a in self])

    def rank(self, other=None):
        """Return the degree of *other* (default: ``self``).

        This is the index of the highest non-zero coefficient. The zero
        polynomial (including the empty list) reports 0, the same as constants.
        """
        if other is None:
            other = self
        return next((len(other) - idx - 1 for idx, val in enumerate(reversed(other)) if val != 0), 0)

    def is_zero(self):
        """Return True if every coefficient is zero (or there are none)."""
        return not any(coeff != 0 for coeff in self)

    def _pol(self, object_):
        """Coerce *object_* to a Polynom (scalars become constants)."""
        if isinstance(object_, self.__class__):
            return object_
        else:
            return self.__class__(object_)

    def _list(self, object_):
        if isinstance(object_, self.__class__):
            return list(object_)
        else:
            return object_

    def _assign(self, coeffs):
        """Overwrite this polynomial's coefficients in place and return self.

        Copying *coeffs* first keeps this safe even when *coeffs* is self.
        """
        self[:] = list(coeffs)
        return self

    def __add__(self, other):
        return Polynom([a + b for a, b in zip_longest(self, self._pol(other), fillvalue=0)])

    def __neg__(self):
        return Polynom([-a for a in self])

    def __sub__(self, other):
        return self.__add__(self._pol(other).__neg__())

    def __mul__(self, other):
        # Treat scalars as constant polynomials so that `p * 3` and `3 * p`
        # (through __rmul__) work instead of failing on len(3).
        if not isinstance(other, (list, tuple)):
            other = [other]
        res = Polynom((len(self) + len(other) - 1) * [0])
        for i, a in enumerate(self):
            for j, b in enumerate(other):
                res[i + j] += a * b
        return res

    def _divide(self, other):
        """Long-divide self by the non-zero polynomial *other*.

        Returns ``(quotient, remainder)`` as untrimmed Polynoms. Callers must
        already have rejected a zero divisor and ``self.rank() < other.rank()``.
        Quotient coefficients are computed with ``//``, so the result is exact
        only when the divisor's leading coefficient divides them (e.g. a monic
        divisor, as over GF(2)). Callers that need coefficients modulo p
        reduce afterwards with elem_mod.
        """
        divisor_rank = other.rank()
        lead = other[divisor_rank]
        rank_diff = self.rank() - divisor_rank
        quotient = Polynom((rank_diff + 1) * [0])
        # Padding only matters for an empty dividend, which would otherwise
        # raise IndexError below.
        remainder = Polynom(list(self) + [0] * (divisor_rank + 1 - len(self)))
        for shift in range(rank_diff, -1, -1):
            coeff = remainder[shift + divisor_rank] // lead
            quotient[shift] = coeff
            for j in range(divisor_rank + 1):
                remainder[shift + j] -= coeff * other[j]
        return quotient, remainder

    def __floordiv__(self, other):
        other = self._pol(other)
        if other.is_zero():
            raise ZeroDivisionError
        if self.rank() < other.rank():
            return Polynom([])
        return self._divide(other)[0]

    def __mod__(self, other):
        other = self._pol(other)
        if other.is_zero():
            raise ZeroDivisionError
        if self.rank() < other.rank():
            # Return a copy, never self. Otherwise `p %= q` would clear p and
            # then try to refill it from the now-empty p.
            return Polynom(self)
        remainder = self._divide(other)[1]
        return Polynom(remainder[:remainder.rank() + 1])

    def __divmod__(self, other):
        # A single long division produces both results.
        other = self._pol(other)
        if other.is_zero():
            raise ZeroDivisionError
        if self.rank() < other.rank():
            return Polynom([]), Polynom(self)
        quotient, remainder = self._divide(other)
        return quotient, Polynom(remainder[:remainder.rank() + 1])

    def __iadd__(self, other):
        return self._assign(self.__add__(other))

    def __isub__(self, other):
        return self._assign(self.__sub__(other))

    def __imul__(self, other):
        return self._assign(self.__mul__(other))

    def __ifloordiv__(self, other):
        return self._assign(self.__floordiv__(other))

    def __imod__(self, other):
        return self._assign(self.__mod__(other))

    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other):
        return self.__neg__().__add__(self._pol(other))

    def __rfloordiv__(self, other):
        return self._pol(other).__floordiv__(self)

    def __rmod__(self, other):
        return self._pol(other).__mod__(self)

    def __rdivmod__(self, other):
        return self._pol(other).__divmod__(self)

    def __eq__(self, other):
        # Missing coefficients count as zero, so [1] == [1, 0].
        for a, b in zip_longest(self, other, fillvalue=0):
            if a != b:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)
class GaloisExtensionField(Field):
    """The field GF(prime**power) as polynomials modulo ``polynom``.

    Raw element values are Polynoms of exactly ``power`` coefficients (lowest
    degree first), each reduced modulo ``prime``.
    """

    # OS-entropy RNG: these elements become secret-sharing coefficients, and
    # the default Mersenne Twister in `random` is predictable from its output.
    _rng = random.SystemRandom()

    def __init__(self, prime, power, polynom):
        self._prime = prime
        self._power = power
        self._polynom = Polynom(polynom)
        # NOTE: this holds the number of elements (prime**power), not the
        # mathematical characteristic (which would be `prime`).
        self.characteristic = prime ** power

    def __call__(self, value):
        """Build an element from a coefficient sequence or a non-negative int.

        An int is read as an index: its base-``prime`` digits, lowest first,
        become the coefficients (see _convert).
        NOTE(review): inputs with more than ``power`` coefficients (e.g. ints
        >= prime**power) are not reduced modulo ``polynom``.
        """
        if isinstance(value, int):
            value = self._convert(value)
        value = Polynom([int(elem) % self._prime for elem in value] + (self._power - len(value)) * [0])
        return FieldElem(self, value)

    def zero(self):
        return FieldElem(self, self._canonical([]))

    def add(self, a, b):
        return self._canonical(a + b)

    def neg(self, a):
        return self._canonical(-a)

    def one(self):
        return FieldElem(self, self._canonical([1]))

    def mul(self, a, b):
        return self._canonical((a * b) % self._polynom)

    def inv(self, a):
        return self._canonical(mulinv(a, self._polynom, zero=lambda: self._canonical([]), one=lambda: self._canonical([1]), mod=lambda x: x.elem_mod(self._prime)))

    def rand(self):
        """Return a uniformly random element.

        randrange(prime) draws each coefficient from 0..prime-1. The previous
        randint(0, prime) also produced `prime`, which reduces to 0 and made 0
        twice as likely as any other coefficient.
        """
        value = [self._rng.randrange(self._prime) for _ in range(self._power)]
        return self.__call__(value)

    def value(self, val):
        """Inverse of _convert: read coefficients as base-``prime`` digits."""
        result = 0
        for coeff in reversed(val):
            result = result * self._prime + coeff
        return result

    def _convert(self, idx):
        """Return the base-``prime`` digits of *idx*, lowest first.

        Raises:
            ValueError: for negative *idx*. The digit loop would never
                terminate on negative values, because floor division keeps
                them at -1.
        """
        if idx < 0:
            raise ValueError(f"cannot convert negative index {idx} to a field element")
        digits = []
        while idx != 0:
            idx, digit = divmod(idx, self._prime)
            digits.append(digit)
        return digits

    def _canonical(self, value):
        # Pad/truncate to `power` coefficients and reduce each modulo prime.
        # This does not reduce modulo `polynom`; mul reduces before calling it.
        return Polynom(value[:self._power] + (self._power - len(value)) * [0]).elem_mod(self._prime)

    def __repr__(self):
        return f"GF({self._prime}**{self._power}/{self._polynom})"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import shamir | |
import finite | |
import itertools | |
def powerset(iterable):
    """Iterate over every subset of *iterable* as a tuple, smallest subsets first.

    The input is consumed immediately; the subsets are produced lazily.
    """
    items = tuple(iterable)
    sizes = range(len(items) + 1)
    return (combo for size in sizes for combo in itertools.combinations(items, size))
def shamir_test(m, s, field=finite.rationalField):
    """Split a random secret into *s* shares (threshold *m*) and try every subset.

    Prints the polynomial, the secret and the shares. Then, for every subset
    of share indices, it prints the recovered value, prefixed with a check
    mark if it equals the secret and a cross otherwise.
    """
    secret, shares, poly = shamir._make_random_shares(minimum=m, shares=s, field=field)
    print(f"Polynom: {poly}")
    print(f"Secret: {secret}")
    print(f"Shares: {shares}")
    for idxs in powerset(range(s)):
        subset = [shares[i] for i in idxs]
        recovered = shamir.recover_secret(subset, field)
        if recovered == secret:
            mark = "✓ "
        else:
            mark = "✗ "
        print(mark + f"Combination: {idxs} Value: {recovered}")
# Demo fields.
# gp521: presumably the Mersenne prime 2**521 - 1 (going by the name) -- not
# verified here. gf521 is built from it but main() does not use it.
gp521 = 6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151
# A 30-digit modulus (presumably prime, per the name) for the GF(p) demo.
prime30 = 280829369862134719390036617067
gf521 = finite.GaloisPrimeField(gp521)
gf30 = finite.GaloisPrimeField(prime30)
# GF(2**8) modulo the AES/Rijndael polynomial x**8 + x**4 + x**3 + x + 1
# (coefficients listed lowest degree first).
gf256 = finite.GaloisExtensionField(2, 8, [1, 1, 0, 1, 1, 0, 0, 0, 1]) #rijndael
def main():
    """Run the 3-of-4 sharing demo over Q, a 30-digit prime field and GF(2**8).

    Output sections are separated by a blank line.
    """
    demos = (
        ("Q (rational numbers):", ()),
        ("F_280829369862134719390036617067 (modular numbers):", (gf30,)),
        ("F_256 Rijndael (polynomials):", (gf256,)),
    )
    for position, (title, extra_args) in enumerate(demos):
        if position:
            print()
        print(title)
        shamir_test(3, 4, *extra_args)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from finite import * | |
def _polynom_value(poly, x, field=rationalField):
    """Evaluate the polynomial *poly* at *x* with Horner's rule.

    *poly* lists coefficients lowest degree first. The result starts from
    ``field.zero()``, so it is an element of *field*; an empty *poly*
    evaluates to zero.
    """
    result = field.zero()
    for coefficient in reversed(poly):
        result *= x
        result += coefficient
    return result
def _make_random_shares(minimum, shares, field=rationalField):
    """Draw a random secret and split it into Shamir shares.

    A polynomial with *minimum* random coefficients (``field.rand()``, lowest
    degree first) is drawn, and its constant term is the secret. Share i, for
    i in 1..shares, is the point ``(field(i), poly(field(i)))``.

    Returns:
        ``(secret, points, poly)``.

    Raises:
        ValueError: if ``minimum > shares``, if ``minimum < 1``, or if the
            field has too few elements to give every share its own non-zero
            x-coordinate.
    """
    if minimum > shares:
        raise ValueError("pool secret would be irrecoverable")
    if minimum < 1:
        # Without this check, poly would be empty and poly[0] would raise IndexError.
        raise ValueError("minimum must be at least 1")
    # For the finite fields in `finite`, `characteristic` holds the number of
    # elements (None for Q). GF(p) reduces field(i) modulo p, so i == p would
    # put a share at x == 0, which is the secret itself, and larger i repeat
    # x-coordinates.
    size = getattr(field, "characteristic", None)
    if size is not None and shares >= size:
        raise ValueError("too many shares for this field")
    poly = [field.rand() for _ in range(minimum)]
    points = [(field(i), _polynom_value(poly, field(i), field))
              for i in range(1, shares + 1)]
    return poly[0], points, poly
def make_random_shares(minimum, shares, field=rationalField):
    """Public wrapper: split a fresh random secret into *shares* points.

    Any *minimum* of the points recover the secret. Returns ``(secret, points)``;
    the generating polynomial from _make_random_shares is dropped.
    """
    return _make_random_shares(minimum, shares, field)[:2]
def _lagrange_interpolate(x, x_s, y_s, field=rationalField):
    """Evaluate at *x* the polynomial of degree < k through the k given points.

    Uses the Lagrange form::

        f(x) = sum_i y_i * prod_{j != i} (x - x_j) / (x_i - x_j)

    with all arithmetic carried out in *field*. With no points the result is
    ``field.zero()``.

    Raises:
        ValueError: if *x_s* and *y_s* differ in length or *x_s* has duplicates.
    """
    x_s = list(x_s)
    y_s = list(y_s)
    if len(x_s) != len(y_s):
        raise ValueError("x_s and y_s must have the same length")
    # An explicit check rather than `assert`, which `python -O` strips.
    if len(x_s) != len(set(x_s)):
        raise ValueError("points must be distinct")

    def product(vals):
        value = field.one()
        for v in vals:
            value *= v
        return value

    result = field.zero()
    for i, (x_i, y_i) in enumerate(zip(x_s, y_s)):
        others = x_s[:i] + x_s[i + 1:]
        num = product(x - o for o in others)
        den = product(x_i - o for o in others)
        # Field division is exact, so each term is divided directly. The
        # common-denominator trick is only needed for integer arithmetic.
        result += num * y_i / den
    return result
def recover_secret(shares, field=rationalField):
    """Recover the secret, i.e. the sharing polynomial's value at 0.

    *shares* is any iterable of ``(x, y)`` pairs, such as the points returned
    by make_random_shares. An empty iterable yields a zero-valued element of
    *field*.

    Raises:
        ValueError: if an entry is not exactly an ``(x, y)`` pair. The previous
            blanket ``except ValueError`` silently treated such input as
            "no shares" and returned zero.
    """
    points = list(shares)
    x_s = [x for x, _ in points]
    y_s = [y for _, y in points]
    return _lagrange_interpolate(field.zero(), x_s, y_s, field)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment