Last active
April 3, 2023 19:59
-
-
Save AdamISZ/77651979025d16b778494047c86c3a7c to your computer and use it in GitHub Desktop.
Demo of logarithmic size ring signature algorithm (Groth and Kohlweiss '14)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# Usage text printed by the __main__ block when the command-line argument is
# missing or invalid.
# NOTE(review): this name shadows the builtin help(); it is kept because the
# __main__ block refers to it by this name.
help = """
A demonstration of the algorithm of:
Groth and Kohlweiss 2014 "How to leak a secret and spend a coin."
https://eprint.iacr.org/2014/764.pdf
This uses the Joinmarket bitcoin backend, mostly just for its encapsulation
of the package python-bitcointx (`pip install bitcointx` or github:
https://github.com/Simplexum/python-bitcointx).
(though it also uses a couple other helpful functions, so if you do
want to run it then download and run `./install.sh` from:
https://github.com/Joinmarket-Org/joinmarket-clientserver
).
Provide a single argument, which should be a power of 2, like 256, as the
total number of public keys in the ring. The secret key is generated at random
and so are all the decoys, but they are valid secp256k1 keys.
The ring signature is generated and displayed, then verified according to the
algorithm in the paper. Each key is treated as a Pedersen commitment to zero,
i.e. if the key is x * G, then the commitment is (0 * H + x * G).
The signature proves knowledge of the opening to one of the commitments, to zero.
The message being signed is here hardcoded to "hello" since that doesn't matter.
Examples, which you can verify: ring size 256, 1.9kB, ring size 1024, 2.3kB.
"""
import os | |
import sys | |
import random | |
import math | |
import hashlib | |
from bitcointx.core.key import CKey | |
import jmbitcoin as btc | |
from jmclient.podle import getNUMS | |
from jmbase import bintohex | |
# Order of the secp256k1 group; all scalar arithmetic is reduced modulo this.
groupN = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
# Sentinel standing in for the point at infinity (see pointadd / pointmult).
infty = "INFTY"


# Probably overkill, but it encapsulates arithmetic in the group's scalar field;
# this class handles the modular arithmetic of * and +.
class Scalar(object):
    """An integer reduced modulo the secp256k1 group order ``groupN``.

    Supports ``+``, ``-``, ``*`` and unary ``-`` against other Scalars or plain
    ints, with the Scalar on either side of the operator; every result is a new
    reduced Scalar. Two Scalars compare equal when their reduced values match.
    """

    def __init__(self, x):
        self.x = x % groupN

    @staticmethod
    def _as_int(value):
        """Return the integer behind an int/Scalar operand, or None if unsupported."""
        if isinstance(value, Scalar):
            return value.x
        if isinstance(value, int):
            return value
        return None

    def to_bytes(self):
        """Serialize as 32 big-endian bytes."""
        return int.to_bytes(self.x, 32, byteorder="big")

    @classmethod
    def from_bytes(cls, b):
        """Parse big-endian bytes into a Scalar (reduced mod groupN)."""
        return cls(int.from_bytes(b, byteorder="big"))

    @classmethod
    def pow(cls, base, exponent):
        """Return base ** exponent mod groupN; base may be an int or a Scalar.

        Raises TypeError for any other base type.
        """
        base_int = cls._as_int(base)
        if base_int is None:
            raise TypeError("unsupported base type: %s" % type(base).__name__)
        return cls(pow(base_int, exponent, groupN))

    def __add__(self, other):
        y = self._as_int(other)
        if y is None:
            # Let Python try the other operand / raise TypeError.
            return NotImplemented
        return Scalar(self.x + y)

    # Addition is commutative; also makes sum() over Scalars work (0 + Scalar).
    __radd__ = __add__

    def __sub__(self, other):
        y = self._as_int(other)
        if y is None:
            return NotImplemented
        return Scalar(self.x - y)

    def __rsub__(self, other):
        y = self._as_int(other)
        if y is None:
            return NotImplemented
        return Scalar(y - self.x)

    def __mul__(self, other):
        # Any int is accepted, not only the 0/1 bit values.
        y = self._as_int(other)
        if y is None:
            return NotImplemented
        return Scalar(self.x * y)

    __rmul__ = __mul__

    def __neg__(self):
        return Scalar(-self.x)

    def __eq__(self, other):
        if isinstance(other, Scalar):
            return self.x == other.x
        return NotImplemented

    # Defining __eq__ would otherwise make instances unhashable.
    def __hash__(self):
        return hash(self.x)

    def __str__(self):
        return str(self.x)

    __repr__ = __str__
def binmult(a, b):
    """Multiply two big-endian byte strings as scalars mod groupN.

    Returns the product as 32 big-endian bytes. The shortcuts only apply when
    `a` is the int 0 (returns 32 zero bytes) or the int 1 (returns `b`
    unchanged); byte-string arguments always take the Scalar path.
    """
    if a == 0:
        return bytes(32)
    if a == 1:
        return b
    product = Scalar.from_bytes(a) * Scalar.from_bytes(b)
    return product.to_bytes()
def pointadd(points):
    """Sum a list of points, treating the `infty` sentinel as the identity.

    Returns `infty` if nothing but sentinels were given, the single point if
    only one remains, otherwise btc.add_pubkeys over the remaining points.

    NB: this is not fully correct, since it does not account for points that
    cancel to the point at infinity (it is unclear how such a result would be
    serialized); that case does not arise in this module's uses.
    """
    finite = [p for p in points if p != infty]
    if not finite:
        return infty
    if len(finite) == 1:
        return finite[0]
    return btc.add_pubkeys(finite)
def pointmult(multiplier, point):
    """Return multiplier * point as a pubkey object.

    `multiplier` is a big-endian byte string; the ints 1 and 0 are also
    accepted as shortcuts (returning `point` and `infty` respectively). An
    all-zero byte string likewise yields the `infty` sentinel.
    """
    if multiplier == 1:
        return point
    if multiplier == 0 or int.from_bytes(multiplier, byteorder="big") == 0:
        return infty
    return btc.multiply(multiplier, point, return_serialized=False)
def delta(a, b):
    """Kronecker delta: 1 when a == b, otherwise 0."""
    if a == b:
        return 1
    return 0
def poly_mult_lin(coeffs, a, b):
    """Multiply a polynomial by the linear factor (a*x + b).

    `coeffs` holds the coefficients in *decreasing* order of exponent (from
    len(coeffs)-1 down to 0); the result uses the same ordering and has one
    more entry. a, b and all coefficients are Scalars.
    """
    # Top power: only a*x times the old leading coefficient reaches it.
    leading = [a * coeffs[0]]
    # Each interior power: b times the old coefficient at that power, plus
    # a (times x) applied to the old coefficient one power lower.
    interior = [b * higher + a * lower for higher, lower in zip(coeffs, coeffs[1:])]
    # Constant term: b times the old constant term.
    constant = [b * coeffs[-1]]
    return leading + interior + constant
def gen_rand(l=32):
    """Return `l` random bytes from os.urandom (default 32)."""
    nbytes = l
    return os.urandom(nbytes)
def gen_privkey_set(n, m):
    """Return a lazy generator of n CKey objects, each from m fresh random bytes.

    The second CKey argument is True; presumably the compressed-key flag
    (TODO confirm against python-bitcointx). Not used elsewhere in the
    visible code.
    """
    slots = range(n)
    return (CKey(gen_rand(m), True) for _ in slots)
# The actual secp256k1 generator G; the True argument presumably requests the
# compressed form (TODO confirm against jmbitcoin.getG).
G = btc.getG(True)
# Second generator H for Pedersen commitments, reusing the NUMS
# ("nothing up my sleeve") point code from Joinmarket's PoDLE module (index 1).
H = getNUMS(1)
def pedersen_commit(message, randomness):
    """Return the Pedersen commitment message * H + randomness * G.

    Both arguments are 32-byte big-endian strings; `message` may also be the
    int 0 or 1, which skips the scalar multiplication of H (useful when
    committing to bits). G carries the randomness so that an ordinary bitcoin
    pubkey x*G is itself a commitment to 0 with randomness x.
    """
    blinding = pointmult(randomness, G)
    if message == 0:
        return blinding
    message_point = H if message == 1 else pointmult(message, H)
    return pointadd([message_point, blinding])
def bit_decomp(n, m):
    """Return the low m bits of integer n as a list of 0/1, most significant first."""
    bits = []
    for shift in range(m - 1, -1, -1):
        bits.append((n >> shift) & 1)
    return bits
def hash_transcript(s):
    """Return the 32-byte SHA-256 digest of the transcript bytes `s`.

    Used to derive the Fiat-Shamir challenge in signing and verification.
    """
    hasher = hashlib.sha256()
    hasher.update(s)
    return hasher.digest()
def get_bits_from_ring(ring):
    """Return the number of bits needed to index every member of `ring`.

    This is ceil(log2(len(ring))), computed exactly with integer bit_length
    instead of floating-point logarithms, which can overshoot by one for large
    exact powers of two. A one-element ring needs 0 bits.

    Raises ValueError for an empty ring.
    """
    ring_size = len(ring)
    if ring_size < 1:
        raise ValueError("ring must contain at least one public key")
    return (ring_size - 1).bit_length()
def _index_polynomials(ring_size, bits, lbits, a):
    """Return the coefficients of p_i(x) for every ring index i < ring_size.

    p_i(x) = prod_j f_{j, i_j}(x), with f_{j,1}(x) = delta(1, l_j) * x + a_j and
    f_{j,0}(x) = delta(0, l_j) * x - a_j, where i_j and l_j are the j-th bits of
    i and of the signer's index. Each returned list is in *increasing* order of
    power (entry k is the Scalar coefficient of x^k).
    """
    polys = []
    for i in range(ring_size):
        coeffs = [Scalar(1)]
        for ibit, lbit, a_j in zip(bit_decomp(i, bits), lbits, a):
            a_scalar = Scalar.from_bytes(a_j)
            constant = a_scalar if ibit == 1 else a_scalar * Scalar(-1)
            coeffs = poly_mult_lin(coeffs, Scalar(delta(ibit, lbit)), constant)
        # poly_mult_lin keeps coefficients in decreasing order of power; flip them.
        polys.append(coeffs[::-1])
    return polys


def ring_sign_groth14(seckey, message, decoys):
    """Produce a Groth-Kohlweiss '14 ring signature on `message`.

    seckey: 32-byte secret key of the real signer.
    message: bytes to sign; bound into the Fiat-Shamir challenge.
    decoys: iterable of decoy public keys; the real public key is added to them.

    The ring (decoys + real key) must have a power-of-two size of at least 2;
    it is shuffled with random.SystemRandom (OS entropy) so the signer's
    position is unpredictable.

    Returns (sig, ring): sig is the flat list
    comm_lj + comm_aj + comm_ljaj + comm_dk + f + za + zb + [zdsum]
    (4 * bits points, then 3 * bits + 1 Scalars) and ring is the shuffled ring
    in the order the verifier must use.

    Raises ValueError if the ring size is not a power of two >= 2. With no
    decoys, zdsum would equal the secret key itself.
    """
    real_pub = btc.privkey_to_pubkey(seckey)
    ring = list(decoys) + [real_pub]
    ring_size = len(ring)
    if ring_size < 2 or ring_size & (ring_size - 1):
        raise ValueError(
            "ring size must be a power of 2 and at least 2, got %d" % ring_size)
    # The signer's position is exactly what the ring signature hides, so use the
    # OS CSPRNG rather than the predictable Mersenne Twister behind random.shuffle.
    random.SystemRandom().shuffle(ring)
    signer_index = ring.index(real_pub)
    print("got this index for the signer: ", signer_index)
    bits = get_bits_from_ring(ring)
    print("For encoding the indexes of the signers, got this number of bits: ", bits)
    lbits = bit_decomp(signer_index, bits)
    print("Got these bits for the signer's index: ", lbits)

    # Per-bit 32-byte random values: a masks the index bits; r, s, t and rho
    # are the blinding factors of the four per-bit commitments.
    r = [gen_rand() for _ in range(bits)]
    a = [gen_rand() for _ in range(bits)]
    s = [gen_rand() for _ in range(bits)]
    t = [gen_rand() for _ in range(bits)]
    rho = [gen_rand() for _ in range(bits)]

    polys = _index_polynomials(ring_size, bits, lbits, a)

    # First move of the sigma protocol: the per-bit commitments.
    comm_lj, comm_aj, comm_ljaj, comm_dk = [], [], [], []
    for j in range(bits):
        comm_lj.append(pedersen_commit(lbits[j], r[j]))
        comm_aj.append(pedersen_commit(a[j], s[j]))
        # Commitment to the product l_j * a_j, which is 0 or a_j since l_j is a bit.
        comm_ljaj.append(pedersen_commit(a[j] if lbits[j] else 0, t[j]))
        # c_dk ('k' in the paper; here simply j): sum_i p_{i,j} * P_i + rho_j * G.
        dk_terms = [pointmult(poly[j].to_bytes(), pub) for poly, pub in zip(polys, ring)]
        comm_dk.append(pointadd(dk_terms + [pedersen_commit(0, rho[j])]))

    # Second move (Fiat-Shamir): challenge = H(commitments, ring, message).
    transcript = b",".join(
        [str(item).encode() for item in [comm_lj, comm_aj, comm_ljaj, comm_dk, ring]]
    ) + b"," + message
    x = Scalar.from_bytes(hash_transcript(transcript))
    print("Got challenge value: ", x)

    # Third move: the responses.
    f, za, zb, zd = [], [], [], []
    for j in range(bits):
        r_j = Scalar.from_bytes(r[j])
        f_j = x * lbits[j] + Scalar.from_bytes(a[j])
        f.append(f_j)
        za.append(r_j * x + Scalar.from_bytes(s[j]))
        zb.append(r_j * (x - f_j) + Scalar.from_bytes(t[j]))
        zd.append(Scalar.from_bytes(rho[j]) * Scalar.pow(x.x, j))
    zdsum = Scalar.from_bytes(seckey) * Scalar.pow(x.x, bits) - sum(zd)

    print("Signing is complete.\n")
    return ((comm_lj + comm_aj + comm_ljaj + comm_dk + f + za + zb + [zdsum]), ring)
def decompose_signature(sig, bits):
    """Split a flat signature into its seven per-bit lists plus the final zdsum.

    The seven lists, each of length `bits`, are in order: comm_lj, comm_aj,
    comm_ljaj, comm_dk, f, za, zb. The input sequence is not modified.

    Returns a list of 8 items: the seven lists followed by the zdsum element.
    Raises ValueError if len(sig) != 7 * bits + 1.
    """
    items = list(sig)
    expected = 7 * bits + 1
    if len(items) != expected:
        raise ValueError(
            "signature has %d elements, expected %d" % (len(items), expected))
    groups = [items[k * bits:(k + 1) * bits] for k in range(7)]
    return groups + [items[-1]]
def verify_groth14(sig, message, ring):
    """Verify a Groth-Kohlweiss '14 ring signature on `message` over `ring`.

    sig: the flat list produced by ring_sign_groth14 (4 * bits commitment
    points, then 3 * bits + 1 Scalars).
    ring: the public keys, in exactly the order the signer returned them.

    Returns True if all checks pass, otherwise False. The reason for a False
    result is printed. The caller's `sig` list is left unmodified.
    """
    ring_size = len(ring)
    # Every index 0 .. 2**bits - 1 must be a real ring member. Otherwise a prover
    # can commit to a non-existent index, whose p_i(x) never appears in check 3.
    # That makes check 3 pass with zdsum = -sum(rho_k * x^k), i.e. with no secret key.
    if ring_size < 2 or ring_size & (ring_size - 1):
        print("ring size must be a power of 2 and at least 2, got: ", ring_size)
        return False
    bits = get_bits_from_ring(ring)
    expected_len = 7 * bits + 1
    if len(sig) != expected_len:
        print("malformed signature; expected this many elements: ", expected_len)
        return False
    # Pass a copy so the caller's signature list is never modified.
    comm_lj, comm_aj, comm_ljaj, comm_dk, f, za, zb, zdsum = decompose_signature(list(sig), bits)

    # Recompute the Fiat-Shamir challenge exactly as the signer did.
    transcript = b",".join(
        [str(item).encode() for item in [comm_lj, comm_aj, comm_ljaj, comm_dk, ring]]
    ) + b"," + message
    x = Scalar.from_bytes(hash_transcript(transcript))
    print("In verifying, got challenge value: ", x)

    for j in range(bits):
        # Check 1: x * c_lj + c_aj == Com(f_j; za_j).
        lhs1 = pointadd([pointmult(x.to_bytes(), comm_lj[j]), comm_aj[j]])
        if lhs1 != pedersen_commit(f[j].to_bytes(), za[j].to_bytes()):
            print("failed first check at bit: ", j)
            return False
        # Check 2: (x - f_j) * c_lj + c_ljaj == Com(0; zb_j); this shows
        # l_j * (1 - l_j) = 0, i.e. the committed l_j is a bit.
        lhs2 = pointadd([pointmult((x - f[j]).to_bytes(), comm_lj[j]), comm_ljaj[j]])
        if lhs2 != pedersen_commit(0, zb[j].to_bytes()):
            print("failed second check at bit: ", j)
            return False

    # Check 3: sum_i p_i(x) * P_i - sum_k x^k * c_dk == Com(0; zdsum), where
    # p_i(x) = prod_j (f_j if i_j == 1 else x - f_j).
    fc1 = pointadd([pointmult((Scalar(-1) * Scalar.pow(x.x, k)).to_bytes(), comm_dk[k])
                    for k in range(bits)])
    fc2s = []
    for i, pub in enumerate(ring):
        coeff = Scalar(1)
        for ibit, f_j in zip(bit_decomp(i, bits), f):
            coeff *= f_j if ibit == 1 else x - f_j
        fc2s.append(pointmult(coeff.to_bytes(), pub))
    fc2 = pointadd(fc2s)
    if pointadd([fc1, fc2]) != pedersen_commit(0, zdsum.to_bytes()):
        print("failed third check")
        return False
    return True
def hexer(x):
    """Hex-encode one signature element: a Scalar (via its 32-byte form) or raw bytes."""
    raw = x.to_bytes() if isinstance(x, Scalar) else x
    return bintohex(raw)
def test(ringsize):
    """Run one sign-then-verify round trip with a ring of `ringsize` public keys.

    A random secret key signs b"hello", hidden among ringsize - 1 random decoy
    public keys. The function prints the signature elements (hex), the
    signature size in bytes, and whether the signature verified.
    """
    seckey = gen_rand()
    decoys = [btc.privkey_to_pubkey(CKey(gen_rand())) for _ in range(ringsize - 1)]
    message = b"hello"
    sig, ring = ring_sign_groth14(seckey, message, decoys)
    print("Here are the elements of the ring signature:")
    sig_hex = [hexer(x) for x in sig]
    print(sig_hex)
    # Two hex characters encode one byte, so halve the hex length to report bytes.
    print("The signature length in bytes is : ", sum(len(h) for h in sig_hex) // 2)
    if verify_groth14(sig, message, ring):
        print("Test succeeded; the ring signature verifies as valid, for the given message and ring.")
    else:
        print("Test failed; the ring signature did not verify.")
if __name__ == "__main__":
    # Exactly one argument is expected: the ring size.
    if len(sys.argv) != 2:
        print(help)
        sys.exit(1)
    try:
        rs = int(sys.argv[1])
        # Explicit checks instead of `assert`, which `python -O` strips out.
        if rs < 2:
            raise ValueError("ring size must be at least 2, got %d" % rs)
        # Exact integer power-of-two test; the float-based
        # math.log(rs, 2).is_integer() can misjudge large powers of two.
        if (rs & (rs - 1)) != 0:
            raise ValueError("ring size must be a power of 2, got %d" % rs)
    except ValueError as e:
        print(repr(e))
        print(help)
        sys.exit(1)
    test(rs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment