Skip to content

Instantly share code, notes, and snippets.

@Goheeca
Last active March 13, 2019 18:19
Show Gist options
  • Save Goheeca/ae59f472e41e705ba72113a28760831b to your computer and use it in GitHub Desktop.
Save Goheeca/ae59f472e41e705ba72113a28760831b to your computer and use it in GitHub Desktop.
Shamir's Secret Sharing Scheme
from fractions import Fraction
import random
from itertools import zip_longest
class Field(object):
    """Abstract interface for a field whose elements are wrapped in FieldElem.

    Subclasses supply two kinds of operations:

    * constructors returning wrapped elements: ``__call__`` (convert a plain
      value), ``zero``, ``one`` and ``rand``;
    * primitives working on *raw* (unwrapped) values, which FieldElem calls
      to implement its operators: ``add``, ``neg``, ``mul`` and ``inv``.

    ``value`` maps a raw value back to a plain Python number. Every method of
    this base class raises NotImplementedError.
    """

    def __init__(self):
        # Stays None unless a concrete subclass assigns it.
        self.characteristic = None

    # -- constructors of wrapped elements ---------------------------------

    def __call__(self, value):
        """Convert ``value`` into an element of this field."""
        raise NotImplementedError

    def zero(self):
        """Return the additive identity as an element."""
        raise NotImplementedError

    def one(self):
        """Return the multiplicative identity as an element."""
        raise NotImplementedError

    def rand(self):
        """Return a random element."""
        raise NotImplementedError

    # -- primitives on raw values -----------------------------------------

    def add(self, a, b):
        """Return the raw sum of raw values ``a`` and ``b``."""
        raise NotImplementedError

    def neg(self, a):
        """Return the raw additive inverse of ``a``."""
        raise NotImplementedError

    def mul(self, a, b):
        """Return the raw product of raw values ``a`` and ``b``."""
        raise NotImplementedError

    def inv(self, a):
        """Return the raw multiplicative inverse of ``a``."""
        raise NotImplementedError

    def value(self, val):
        """Map the raw value ``val`` back to a plain number."""
        raise NotImplementedError
class FieldElem(object):
    """An element of a Field, wrapping the field's raw representation ``value``.

    Every operator delegates to the owning field's raw primitives
    (``add``/``neg``/``mul``/``inv``). A non-FieldElem operand (e.g. an int) is
    first converted by calling the field. The in-place operators (``+=``,
    ``-=``, ``*=``, ``/=``) replace ``self.value`` and return ``self``, so any
    alias of the element sees the change.
    """

    def __init__(self, field, value):
        self._field = field
        self.value = value

    def __hash__(self):
        # NOTE: the in-place operators mutate ``value``; do not mutate an
        # element while it is stored in a set or used as a dict key.
        return self.value.__hash__()

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return self._field.__repr__() + "<" + self.__str__() + ">"

    def _wrap(self, value):
        """Wrap a raw value as an element of the same field."""
        return FieldElem(self._field, value)

    def _check(self, other, fn=None):
        """Return ``(self.value, raw_other)``, with ``fn`` applied to raw_other if given.

        A FieldElem operand must belong to the same field. Any other operand
        is converted through the field's ``__call__``.
        """
        if isinstance(other, self.__class__):
            assert other._field == self._field, "Elements are from different fields."
            other = other.value
        else:
            other = self._field.__call__(other).value
        return self.value, (other if fn is None else fn(other))

    def __add__(self, other):
        return self._wrap(self._field.add(*self._check(other)))

    def __neg__(self):
        return self._wrap(self._field.neg(self.value))

    def __sub__(self, other):
        # a - b is computed as a + (-b).
        return self._wrap(self._field.add(*self._check(other, self._field.neg)))

    def __mul__(self, other):
        return self._wrap(self._field.mul(*self._check(other)))

    def __invert__(self):
        # ~a is the multiplicative inverse of a.
        return self._wrap(self._field.inv(self.value))

    def __truediv__(self, other):
        # a / b is computed as a * b**-1.
        return self._wrap(self._field.mul(*self._check(other, self._field.inv)))

    def __iadd__(self, other):
        self.value = self._field.add(*self._check(other))
        return self

    def __isub__(self, other):
        self.value = self._field.add(*self._check(other, self._field.neg))
        return self

    def __imul__(self, other):
        self.value = self._field.mul(*self._check(other))
        return self

    def __itruediv__(self, other):
        # Python 3 looks up __itruediv__ for ``/=``. Without it, ``/=`` fell
        # back to __truediv__ and silently rebound to a new object.
        self.value = self._field.mul(*self._check(other, self._field.inv))
        return self

    # Python 2 spelling, kept for backward compatibility.
    __idiv__ = __itruediv__

    # Addition and multiplication are commutative in a field.
    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other):
        # other - self == (-self) + other
        s, o = self._check(other)
        return self._wrap(self._field.add(self._field.neg(s), o))

    def __rtruediv__(self, other):
        # other / self == self**-1 * other
        s, o = self._check(other)
        return self._wrap(self._field.mul(self._field.inv(s), o))

    def __eq__(self, other):
        try:
            s, o = self._check(other)
        except (TypeError, ValueError):
            # The field cannot convert ``other``; let Python fall back
            # (identity comparison) instead of raising.
            return NotImplemented
        return s == o

    def __ne__(self, other):
        try:
            s, o = self._check(other)
        except (TypeError, ValueError):
            return NotImplemented
        return s != o

    def __call__(self):
        """Return the plain value of this element, as defined by the field."""
        return self._field.value(self.value)
class RationalField(Field):
    """The field Q of rational numbers; raw values are fractions.Fraction.

    ``characteristic`` is inherited from Field.__init__ and stays None.
    """

    def __repr__(self):
        return "Q"

    # -- wrapped-element constructors -------------------------------------

    def __call__(self, value):
        # Fraction() accepts ints, floats (converted exactly), strings, ...
        return FieldElem(self, Fraction(value))

    def zero(self):
        return FieldElem(self, Fraction())

    def one(self):
        return FieldElem(self, Fraction(1))

    def rand(self):
        # A standard-normal float, turned into an exact Fraction by __call__.
        sample = random.gauss(0, 1)
        return self(sample)

    # -- raw primitives -----------------------------------------------------

    def add(self, a, b):
        return a + b

    def neg(self, a):
        return -a

    def mul(self, a, b):
        return a * b

    def inv(self, a):
        # Raises ZeroDivisionError for a == 0 (Fraction semantics).
        return 1 / a

    def value(self, val):
        # int() truncates a non-integral Fraction toward zero.
        return int(val)


# Shared module-level instance, used as the default field elsewhere.
rationalField = RationalField()
def xgcd(a, b, zero=lambda: 0, one=lambda: 1, mod=lambda x: x):
    """Extended Euclidean algorithm.

    Return ``(g, x, y)`` such that ``a*x + b*y == g == gcd(a, b)``.

    The helpers make the routine reusable for other Euclidean domains (e.g.
    Polynom): ``zero``/``one`` build fresh identity elements, and ``mod`` is
    applied to every quotient, remainder and Bezout coefficient (for example
    to reduce polynomial coefficients modulo a prime).
    """
    x0, x1, y0, y1 = zero(), one(), one(), zero()
    z = zero()
    while a != z:
        # Tuple assignment: both b // a and b % a use the old a and b.
        q, b, a = mod(b // a), a, mod(b % a)
        y0, y1 = y1, mod(y0 - q * y1)
        x0, x1 = x1, mod(x0 - q * x1)
    return b, x0, y0


def mulinv(a, b, zero=lambda: 0, one=lambda: 1, mod=lambda x: x):
    """Return ``x`` such that ``(x * a) % b == 1``.

    ``zero``, ``one`` and ``mod`` are forwarded to xgcd. The result is
    reduced with ``% b``.

    Raises:
        ZeroDivisionError: if ``a`` has no inverse modulo ``b``, i.e. the gcd
            is not ``one()`` (for example when ``a`` is zero).
    """
    g, x, _ = xgcd(a, b, zero, one, mod)
    if g != one():
        # Fail loudly instead of returning None to the caller.
        raise ZeroDivisionError(f"{a!r} is not invertible modulo {b!r}")
    return x % b
class GaloisPrimeField(Field):
    """Integers modulo ``prime``; raw values are ints in ``[0, prime)``.

    ``prime`` is not checked for primality. With a composite modulus, some
    non-zero values have no inverse.
    """

    def __init__(self, prime):
        super().__init__()
        self._prime = prime
        self.characteristic = prime

    def __call__(self, value):
        return FieldElem(self, int(value) % self._prime)

    def zero(self):
        return FieldElem(self, 0)

    def add(self, a, b):
        return (a + b) % self._prime

    def neg(self, a):
        return -a % self._prime

    def one(self):
        return FieldElem(self, 1)

    def mul(self, a, b):
        return (a * b) % self._prime

    def inv(self, a):
        """Return the raw inverse of ``a``; raise ZeroDivisionError for zero."""
        if a % self._prime == 0:
            raise ZeroDivisionError("zero has no multiplicative inverse")
        return mulinv(a, self._prime)

    def rand(self):
        """Return a uniformly random element.

        randrange excludes the upper bound. The former randint(0, p) could
        return p, which reduces to 0, so zero was drawn twice as often as
        any other value.
        NOTE(review): the ``random`` module is not cryptographically secure;
        consider ``secrets.randbelow`` for real secret sharing.
        """
        return self.__call__(random.randrange(self._prime))

    def value(self, val):
        return val

    def __repr__(self):
        return f"GF({self._prime})"
class Polynom(list):
    """Polynomial stored as a list of coefficients, lowest degree first.

    ``Polynom([1, 0, 1])`` is ``1 + x**2``. Trailing zero coefficients do not
    matter: equality and hashing ignore them.

    ``//`` and ``%`` divide coefficients with integer ``//``. That is exact
    when the divisor's leading coefficient is 1 (e.g. any non-zero divisor
    over GF(2) after ``elem_mod(2)``). Results are not reduced modulo anything;
    callers apply ``elem_mod`` themselves.
    """

    def __init__(self, list_):
        # Any non-list argument becomes a single constant coefficient.
        if isinstance(list_, list):
            super().__init__(list_)
        else:
            super().__init__([list_])

    def __hash__(self):
        # Hash the coefficients up to the highest non-zero one. Polynomials
        # that compare equal (differing only in trailing zeros) then hash
        # equally, while position and the leading coefficient still count.
        coeffs = list(self)
        while coeffs and coeffs[-1] == 0:
            coeffs.pop()
        return hash(tuple(coeffs))

    def elem_mod(self, m):
        """Return a new Polynom with every coefficient reduced modulo ``m``."""
        return Polynom([a % m for a in self])

    def rank(self, other=None):
        """Return the index of the highest non-zero coefficient (the degree).

        Uses ``other`` if given, else ``self``. The zero polynomial gets 0.
        """
        if other is None:
            other = self
        return next((len(other) - idx - 1 for idx, val in enumerate(reversed(other)) if val != 0), 0)

    def is_zero(self):
        """Return True if every coefficient is zero (also for an empty list)."""
        return not any(coeff != 0 for coeff in self)

    def _pol(self, object_):
        """Coerce ``object_`` to a Polynom; a scalar becomes a constant."""
        if isinstance(object_, self.__class__):
            return object_
        return self.__class__(object_)

    def _list(self, object_):
        """Return a plain-list copy of a Polynom; other objects pass through."""
        if isinstance(object_, self.__class__):
            return list(object_)
        return object_

    def __add__(self, other):
        return Polynom([a + b for a, b in zip_longest(self, self._pol(other), fillvalue=0)])

    def __neg__(self):
        return Polynom([-a for a in self])

    def __sub__(self, other):
        return self.__add__(self._pol(other).__neg__())

    def __mul__(self, other):
        # Coerce first so scalars work: ``poly * 3`` and (via __rmul__)
        # ``3 * poly``. Previously len() of a scalar raised TypeError.
        other = self._pol(other)
        res = Polynom((len(self) + len(other) - 1) * [0])
        for i, a in enumerate(self):
            for j, b in enumerate(other):
                res[i + j] += a * b
        return res

    def __floordiv__(self, other):
        """Return the polynomial quotient (see the class note on ``//``).

        Raises ZeroDivisionError when dividing by the zero polynomial, unless
        ``self`` has lower rank, in which case the quotient is ``[]``.
        """
        other = self._pol(other)
        if self.rank() < other.rank():
            return Polynom([])
        if other.is_zero():
            raise ZeroDivisionError
        rank_diff = self.rank() - other.rank()
        result = Polynom((rank_diff + 1) * [0])
        tmp = Polynom(self)  # working copy; self is not modified
        for result_head in range(rank_diff, -1, -1):
            coeff = None
            head = self.rank() - (rank_diff - result_head)
            # Walk the divisor's coefficients from the top, aligned with tmp[head].
            for k, j in zip(range(head, head - other.rank() - 1, -1), range(other.rank(), -1, -1)):
                if coeff is None:
                    coeff = tmp[k] // other[j]
                    result[result_head] = coeff
                tmp[k] -= coeff * other[j]
        return result

    def __mod__(self, other):
        """Return the remainder, trimmed of trailing zeros (see the ``//`` note)."""
        other = self._pol(other)
        if self.rank() < other.rank():
            # Return a copy, never ``self``. An alias made ``p %= q`` empty
            # ``p``, and let later mutation of the result change the operand.
            return Polynom(self)
        if other.is_zero():
            raise ZeroDivisionError
        result = Polynom(self)
        for result_head in range(self.rank(), other.rank() - 1, -1):
            coeff = None
            for k, j in zip(range(result_head, result_head - other.rank() - 1, -1), range(other.rank(), -1, -1)):
                if coeff is None:
                    coeff = result[k] // other[j]
                result[k] -= coeff * other[j]
        return Polynom(result[:result.rank() + 1])

    def __divmod__(self, other):
        # TODO make it effective: the division is currently performed twice.
        return (self.__floordiv__(other), self.__mod__(other))

    # In-place operators replace the contents of this list object.
    def __iadd__(self, other):
        self[:] = self.__add__(other)
        return self

    def __isub__(self, other):
        self[:] = self.__sub__(other)
        return self

    def __imul__(self, other):
        self[:] = self.__mul__(other)
        return self

    def __ifloordiv__(self, other):
        self[:] = self.__floordiv__(other)
        return self

    def __imod__(self, other):
        self[:] = self.__mod__(other)
        return self

    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other):
        return self.__neg__().__add__(self._pol(other))

    def __rfloordiv__(self, other):
        return self._pol(other).__floordiv__(self)

    def __rmod__(self, other):
        return self._pol(other).__mod__(self)

    def __rdivmod__(self, other):
        return self._pol(other).__divmod__(self)

    def __eq__(self, other):
        # Missing coefficients count as zero, so [1, 0] == [1].
        for a, b in zip_longest(self, other, fillvalue=0):
            if a != b:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)
class GaloisExtensionField(Field):
    """Finite field GF(prime**power), built as polynomials over GF(prime)
    modulo ``polynom``.

    Raw values are Polynom coefficient lists (lowest degree first) of exactly
    ``power`` entries, each reduced modulo ``prime``. A non-negative int is
    mapped to an element through its base-``prime`` digits.
    NOTE(review): ``inv`` relies on Polynom ``//``, which divides coefficients
    with integer ``//``. That looks exact for prime == 2; confirm before
    using prime > 2.
    """

    def __init__(self, prime, power, polynom):
        super().__init__()
        self._prime = prime
        self._power = power
        self._polynom = Polynom(polynom)
        if self._polynom.elem_mod(prime).rank() != power:
            raise ValueError("the modulus polynomial must have degree equal to power")
        # NOTE: despite the name, this stores the field order prime**power.
        self.characteristic = prime ** power

    def __call__(self, value):
        if isinstance(value, int):
            value = self._convert(value)
        coeffs = Polynom([int(elem) % self._prime for elem in value])
        if len(coeffs) > self._power:
            # Reduce high-degree input modulo the field polynomial. It used to
            # be kept over-long, and later additions truncated those terms away.
            coeffs = coeffs % self._polynom
        return FieldElem(self, self._canonical(coeffs))

    def zero(self):
        return FieldElem(self, self._canonical([]))

    def add(self, a, b):
        return self._canonical(a + b)

    def neg(self, a):
        return self._canonical(-a)

    def one(self):
        return FieldElem(self, self._canonical([1]))

    def mul(self, a, b):
        return self._canonical((a * b) % self._polynom)

    def inv(self, a):
        """Return the raw inverse of ``a``; raise ZeroDivisionError for zero."""
        if not any(coeff % self._prime for coeff in a):
            raise ZeroDivisionError("zero has no multiplicative inverse")
        return self._canonical(mulinv(a, self._polynom, zero=lambda: self._canonical([]), one=lambda: self._canonical([1]), mod=lambda x: x.elem_mod(self._prime)))

    def rand(self):
        """Return a uniformly random element.

        randrange excludes ``prime``; the former randint(0, prime) could draw
        ``prime``, which reduces to 0, biasing every coefficient toward zero.
        NOTE(review): ``random`` is not cryptographically secure.
        """
        value = [random.randrange(self._prime) for _ in range(self._power)]
        return self.__call__(value)

    def value(self, val):
        """Return the int whose base-``prime`` digits are the coefficients of ``val``."""
        result = 0
        for coeff in reversed(val):
            result *= self._prime
            result += coeff
        return result

    def _convert(self, idx):
        """Return the base-``prime`` digits of ``idx``, least significant first."""
        if idx < 0:
            # Floor division of a negative int never reaches 0, so the loop
            # below would never terminate.
            raise ValueError("negative integers have no element encoding")
        tmp = []
        while idx != 0:
            tmp.append(idx % self._prime)
            idx //= self._prime
        return tmp

    def _canonical(self, value):
        """Return ``value`` truncated or zero-padded to ``power`` coefficients,
        each reduced modulo ``prime``."""
        return Polynom(value[:self._power] + (self._power - len(value)) * [0]).elem_mod(self._prime)

    def __repr__(self):
        return f"GF({self._prime}**{self._power}/{self._polynom})"
import shamir
import finite
import itertools
def powerset(iterable):
    """Return an iterator over every subset of ``iterable`` as a tuple.

    Subsets are produced in order of increasing size, e.g.
    ``list(powerset([1, 2])) == [(), (1,), (2,), (1, 2)]``.
    The input is materialised into a list immediately.
    """
    items = list(iterable)
    sizes = range(len(items) + 1)
    by_size = (itertools.combinations(items, size) for size in sizes)
    return itertools.chain.from_iterable(by_size)
def shamir_test(m, s, field=finite.rationalField):
    """Share a random secret among ``s`` holders with threshold ``m`` in ``field``.

    Prints the polynomial, the secret and the shares, then tries every subset
    of shares (including the empty one) and prints whether it recovered the
    secret. Subsets smaller than ``m`` are expected to show ✗.
    """
    secret, shares, poly = shamir._make_random_shares(minimum=m, shares=s, field=field)
    for label, item in (("Polynom", poly), ("Secret", secret), ("Shares", shares)):
        print(f"{label}: {item}")
    for idxs in powerset(range(s)):
        subset = [shares[i] for i in idxs]
        recovered = shamir.recover_secret(subset, field)
        mark = "✓ " if recovered == secret else "✗ "
        print(mark + f"Combination: {idxs} Value: {recovered}")
# Moduli for the demo prime fields.
# gp521 appears to be the Mersenne prime 2**521 - 1 (as its name suggests) — verify.
gp521 = 6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151
# A 30-digit modulus (named as a prime).
prime30 = 280829369862134719390036617067
# Prime fields GF(p) over the moduli above; only gf30 is used by main() below.
gf521 = finite.GaloisPrimeField(gp521)
gf30 = finite.GaloisPrimeField(prime30)
# GF(2**8) with modulus x^8 + x^4 + x^3 + x + 1 (coefficients listed lowest degree first).
gf256 = finite.GaloisExtensionField(2, 8, [1, 1, 0, 1, 1, 0, 0, 0, 1]) #rijndael
def main():
    """Run the 3-of-4 Shamir demo over Q, GF(prime30) and the Rijndael GF(2**8),
    printing a blank line between the sections."""
    # The Q demo passes no field argument, so shamir_test uses its default.
    demos = (
        ("Q (rational numbers):", ()),
        ("F_280829369862134719390036617067 (modular numbers):", (gf30,)),
        ("F_256 Rijndael (polynomials):", (gf256,)),
    )
    for position, (title, field_args) in enumerate(demos):
        if position > 0:
            print()
        print(title)
        shamir_test(3, 4, *field_args)


if __name__ == '__main__':
    main()
from finite import *
def _polynom_value(poly, x, field=rationalField):
    """Evaluate ``poly`` (coefficients, constant term first) at ``x`` with Horner's rule.

    The accumulator starts as a fresh ``field.zero()``, is updated with the
    in-place operators, and is then returned.
    """
    acc = field.zero()
    for coefficient in reversed(poly):
        # acc <- acc * x + coefficient
        acc *= x
        acc += coefficient
    return acc
def _make_random_shares(minimum, shares, field=rationalField):
    """Split a fresh random secret into ``shares`` points of a random polynomial.

    Args:
        minimum: number of shares needed to recover the secret. The random
            polynomial gets ``minimum`` coefficients (degree ``minimum - 1``).
        shares: number of points to produce, at x = 1 .. ``shares``.
        field: field supplying ``rand()`` and element construction.

    Returns:
        ``(secret, points, poly)``: ``secret`` is ``poly[0]`` (the value at
        x = 0), ``points`` is a list of ``(x, y)`` pairs, and ``poly`` is the
        coefficient list, constant term first.

    Raises:
        ValueError: if ``minimum`` < 1, ``minimum`` > ``shares``, or
            ``shares`` >= ``field.characteristic`` (when that is not None).
    """
    if minimum < 1:
        raise ValueError("at least one share must be required")
    if minimum > shares:
        raise ValueError("pool secret would be irrecoverable")
    # The x coordinates 1..shares must be distinct and non-zero in the field.
    # In GF(p), x = p wraps to 0, and the share at x = 0 *is* the secret.
    # In `finite`, characteristic is the prime for GaloisPrimeField and
    # prime**power for GaloisExtensionField (None for the rationals).
    bound = getattr(field, "characteristic", None)
    if bound is not None and shares >= bound:
        raise ValueError("too many shares for the field size")
    poly = [field.rand() for _ in range(minimum)]
    points = [(field(i), _polynom_value(poly, field(i), field))
              for i in range(1, shares + 1)]
    return poly[0], points, poly
def make_random_shares(minimum, shares, field=rationalField):
    """Create ``shares`` shares of a random secret; any ``minimum`` of them recover it.

    Returns ``(secret, points)``. The generating polynomial that
    _make_random_shares also returns is dropped.
    """
    return _make_random_shares(minimum, shares, field)[:2]
def _lagrange_interpolate(x, x_s, y_s, field=rationalField):
    """Evaluate at ``x`` the polynomial passing through ``(x_s[i], y_s[i])``.

    Uses the Lagrange form over a single common denominator. Each term is
    scaled by the product of all basis denominators, and the sum is divided
    by that product once at the end.
    """
    count = len(x_s)
    assert count == len(set(x_s)), "points must be distinct"

    def product(factors):
        # Fold with in-place multiplication, starting from the field's one.
        acc = field.one()
        for factor in factors:
            acc *= factor
        return acc

    abscissae = list(x_s)
    numerators = []
    denominators = []
    for idx in range(count):
        current = abscissae[idx]
        rest = abscissae[:idx] + abscissae[idx + 1:]
        numerators.append(product(x - xo for xo in rest))
        denominators.append(product(current - xo for xo in rest))

    common = product(denominators)
    total = sum([num * common * y_s[idx] / den
                 for idx, (num, den) in enumerate(zip(numerators, denominators))])
    return total / common
def recover_secret(shares, field=rationalField):
    """Recover the secret (the polynomial's value at x = 0) from ``shares``.

    Args:
        shares: iterable of ``(x, y)`` pairs, as produced by make_random_shares.
            May be empty.
        field: the field the shares live in.

    Returns:
        The Lagrange interpolation of the shares evaluated at ``field.zero()``.
        With fewer shares than the threshold used to create them, this is
        generally not the secret, and nothing here can detect that.

    Raises:
        ValueError: if any share is not an ``(x, y)`` pair.
    """
    points = list(shares)
    # Validate explicitly. The old ``except ValueError`` meant for the empty
    # case also swallowed malformed shares and silently returned a result.
    if any(len(point) != 2 for point in points):
        raise ValueError("each share must be an (x, y) pair")
    if points:
        x_s, y_s = zip(*points)
    else:
        x_s, y_s = [], []
    return _lagrange_interpolate(field.zero(), x_s, y_s, field)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment