Skip to content

Instantly share code, notes, and snippets.

@AdamISZ
Last active April 3, 2023 19:59
Show Gist options
  • Save AdamISZ/77651979025d16b778494047c86c3a7c to your computer and use it in GitHub Desktop.
Save AdamISZ/77651979025d16b778494047c86c3a7c to your computer and use it in GitHub Desktop.
Demo of a logarithmic-size ring signature algorithm (Groth and Kohlweiss '14)
#!/usr/bin/env python
# Usage/description text; the __main__ block prints it when the command-line
# arguments are missing or invalid. (NB: this module-level name shadows the
# builtin help().)
help = """
A demonstration of the algorithm of:
Groth and Kohlweiss 2014 "How to leak a secret and spend a coin."
https://eprint.iacr.org/2014/764.pdf
This uses the Joinmarket bitcoin backend, mostly just for its encapsulation
of the package python-bitcointx (`pip install bitcointx` or github:
https://github.com/Simplexum/python-bitcointx).
(though it also uses a couple other helpful functions, so if you do
want to run it then download and run `./install.sh` from:
https://github.com/Joinmarket-Org/joinmarket-clientserver
).
Provide a single argument, which should be a power of 2, like 256, as the
total number of public keys in the ring. The secret key is generated at random
and so are all the decoys, but they are valid secp256k1 keys.
The ring signature is generated and displayed, then verified according to the
algorithm in the paper. Each key is treated as a Pedersen commitment to zero,
i.e. if the key is x * G, then the commitment is (0 * H + x * G).
The signature proves knowledge of the opening to one of the commitments, to zero.
The message being signed is here hardcoded to "hello" since that doesn't matter.
Examples, which you can verify: ring size 256, 1.9kB, ring size 1024, 2.3kB.
"""
import os
import sys
import random
import math
import hashlib
from bitcointx.core.key import CKey
import jmbitcoin as btc
from jmclient.podle import getNUMS
from jmbase import bintohex
# Order of the secp256k1 group; every Scalar is reduced modulo this value.
groupN = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
# Sentinel standing in for the point at infinity (the group identity), used
# by pointadd/pointmult in place of a serialized pubkey.
infty = "INFTY"


# Probably overkill, but it encapsulates arithmetic in the scalar field:
# this class handles the modular arithmetic of * , + and -.
class Scalar(object):
    """An integer modulo the secp256k1 group order ``groupN``.

    Supports ``+``, ``-`` and ``*`` with other Scalars or plain ints, on
    either side of the operator; every result is a new, fully reduced
    Scalar. Operands of any other type make the operator return
    ``NotImplemented``, so Python raises a ``TypeError``.
    """

    def __init__(self, x):
        self.x = x % groupN

    @staticmethod
    def _value(other):
        """Return the integer value of an int or Scalar operand, else None."""
        if isinstance(other, Scalar):
            return other.x
        if isinstance(other, int):
            return other
        return None

    def to_bytes(self):
        """Serialize as 32 bytes, big-endian."""
        return int.to_bytes(self.x, 32, byteorder="big")

    @classmethod
    def from_bytes(cls, b):
        """Parse a big-endian byte string, reducing it modulo ``groupN``."""
        return cls(int.from_bytes(b, byteorder="big"))

    @classmethod
    def pow(cls, base, exponent):
        """Return ``base ** exponent`` modulo ``groupN``.

        ``base`` may be an int or a Scalar; ``exponent`` is an int.
        """
        if isinstance(base, Scalar):
            base = base.x
        # Inside this method the name resolves to the builtin 3-argument pow.
        return cls(pow(base, exponent, groupN))

    def __add__(self, other):
        y = self._value(other)
        if y is None:
            return NotImplemented
        return Scalar(self.x + y)

    def __radd__(self, other):
        # Addition is commutative; this also makes sum() over Scalars work
        # (sum starts from the int 0).
        return self.__add__(other)

    def __sub__(self, other):
        y = self._value(other)
        if y is None:
            return NotImplemented
        return Scalar(self.x - y)

    def __rsub__(self, other):
        y = self._value(other)
        if y is None:
            return NotImplemented
        return Scalar(y - self.x)

    def __mul__(self, other):
        # Any int is accepted (previously only 0 and 1 were, via special
        # cases; other ints crashed on ``other.x``).
        y = self._value(other)
        if y is None:
            return NotImplemented
        return Scalar(self.x * y)

    def __rmul__(self, other):
        # Multiplication is commutative.
        return self.__mul__(other)

    def __neg__(self):
        return Scalar(-self.x)

    def __str__(self):
        return str(self.x)

    def __repr__(self):
        return str(self.x)
def binmult(a, b):
    """ Multiply two 32 byte big-endian scalars modulo the group order and
    return the product as a 32 byte big-endian string.

    Shortcut: if ``a`` is the plain int 0 the result is 32 zero bytes, and if
    it is the plain int 1, ``b`` is returned unchanged. (A bytes ``a`` never
    compares equal to an int, so the shortcuts only apply to int bits.)
    """
    if a == 0:
        return bytes(32)
    if a == 1:
        return b
    left = Scalar.from_bytes(a)
    right = Scalar.from_bytes(b)
    product = left * right
    return product.to_bytes()
def pointadd(points):
    """ Sum a list of curve points, treating ``infty`` as the identity.

    ``infty`` entries are dropped first. No remaining points gives ``infty``,
    a single point is returned as-is, and otherwise the points are summed
    with btc.add_pubkeys.
    NB: this does not account for cancellation (points summing to the
    identity); it is unclear how a point at infinity would be returned
    serialized in that case, but it does not arise in this module's uses.
    """
    finite = [p for p in points if p != infty]
    if not finite:
        return infty
    if len(finite) == 1:
        return finite[0]
    return btc.add_pubkeys(finite)
def pointmult(multiplier, point):
    """ Return multiplier * point.

    ``multiplier`` is normally a 32 byte big-endian scalar, but the plain
    ints 0 and 1 are special-cased; ``point`` is a pubkey. A zero multiplier
    (int 0 or all-zero bytes) yields the ``infty`` sentinel; otherwise the
    product comes from btc.multiply as a pubkey object, not serialized.
    """
    if multiplier == 1:
        return point
    is_zero = multiplier == 0 or int.from_bytes(multiplier, byteorder="big") == 0
    if is_zero:
        return infty
    return btc.multiply(multiplier, point, return_serialized=False)
def delta(a, b):
    """Kronecker delta: return 1 when ``a == b``, otherwise 0."""
    if a == b:
        return 1
    return 0
def poly_mult_lin(coeffs, a, b):
    """ Multiply a polynomial by the linear factor (a*x + b).

    ``coeffs`` holds the polynomial's coefficients in *decreasing* order of
    exponent (x^(len(coeffs)-1) down to x^0). The result uses the same
    ordering and has one more entry. ``a``, ``b`` and all the returned
    elements are type Scalar.
    """
    # The leading term only picks up a*x and the constant term only picks
    # up b; every middle coefficient mixes two neighbouring inputs.
    head = a * coeffs[0]
    middle = [b * coeffs[k - 1] + a * coeffs[k] for k in range(1, len(coeffs))]
    tail = b * coeffs[-1]
    return [head] + middle + [tail]
def gen_rand(l=32):
    """Return ``l`` random bytes (default 32) from os.urandom."""
    random_bytes = os.urandom(l)
    return random_bytes
def gen_privkey_set(n, m):
    """ Return a lazy generator of ``n`` CKey objects, each constructed from
    ``m`` fresh random bytes with ``True`` as the second CKey argument.
    Nothing is created until the generator is consumed.
    """
    keys = (CKey(gen_rand(m), True) for _index in range(n))
    return keys
# reuse NUMS points code from PoDLE: H is a second generator for the Pedersen
# commitments below. NOTE(review): presumably getNUMS(1) yields a
# "nothing up my sleeve" point whose discrete log w.r.t. G is unknown, as
# Pedersen binding requires; confirm in jmclient.podle.getNUMS.
H = getNUMS(1)
# the actual secp generator (the blinding-factor base in pedersen_commit)
G = btc.getG(True)
def pedersen_commit(message, randomness):
    """ Return the Pedersen commitment message * H + randomness * G.

    ``randomness`` is a 32 byte big-endian scalar. ``message`` is either a
    32 byte scalar or the plain int 0 or 1 (committing to a bit), which are
    handled without a point multiplication on H. Using G for the blinding
    term means a bitcoin pubkey x * G is itself a commitment to 0 with
    randomness x, which is what lets pubkeys serve as the ring.
    """
    blinding = pointmult(randomness, G)
    # A commitment to 0 is just the blinding term.
    if message == 0:
        return blinding
    value_term = H if message == 1 else pointmult(message, H)
    return pointadd([value_term, blinding])
def bit_decomp(n, m):
    """ Return the lowest ``m`` bits of the integer ``n`` as a list of 0/1
    ints, most significant bit first (e.g. bit_decomp(6, 4) == [0, 1, 1, 0]).
    Bits of ``n`` at position ``m`` or higher are ignored.
    """
    return [(n >> shift) & 1 for shift in range(m - 1, -1, -1)]
def hash_transcript(s):
    """ Return the 32 byte SHA256 digest of the byte string ``s``.

    Used to derive the Fiat-Shamir challenge from the proof transcript.
    """
    hasher = hashlib.sha256()
    hasher.update(s)
    return hasher.digest()
def get_bits_from_ring(ring):
    """ Return the number of bits needed to index every member of ``ring``,
    i.e. ceil(log2(len(ring))); a single-member ring needs 0 bits.

    This is computed with exact integer arithmetic. The float expression
    ceil(math.log(n, 2)) is not guaranteed to be exact, so it can add a
    spurious extra bit when n is a power of two.
    Raises ValueError for an empty ring.
    """
    size = len(ring)
    if size < 1:
        raise ValueError("ring must contain at least one public key")
    # For n >= 1, (n - 1).bit_length() == ceil(log2(n)).
    return (size - 1).bit_length()
def ring_sign_groth14(seckey, message, decoys):
    """ Create a Groth-Kohlweiss ('14) ring signature over ``message``.

    Given a 32 byte secret key, a message (bytes) and a list of decoy public
    keys, this algorithm will:
    * calculate the full list of public keys (decoys + the signer's key)
    * arrange the decoys+real randomly and note the index of the real key
    * create the commitments, form the hash challenge and calculate
      the signature data.

    Returns ``(sig, ring)``. ``sig`` is the flat list
    comm_lj + comm_aj + comm_ljaj + comm_dk + f + za + zb + [zdsum]
    (each of the first seven groups has one entry per index bit). ``ring``
    is the shuffled ring, in the order the verifier must use.
    Progress is printed to stdout, including the signer's index. This is
    demo output only: printing the index obviously defeats the anonymity.
    """
    real_pub = btc.privkey_to_pubkey(seckey)
    ring = decoys + [real_pub]
    assert isinstance(ring, list)
    # The signer's anonymity rests on its position being unpredictable, so
    # shuffle with the OS CSPRNG rather than the default Mersenne Twister.
    random.SystemRandom().shuffle(ring)
    l = ring.index(real_pub)
    print("got this index for the signer: ", l)
    bits = get_bits_from_ring(ring)
    print("For encoding the indexes of the signers, got this number of bits: ", bits)
    lbits = bit_decomp(l, bits)
    print("Got these bits for the signer's index: ", lbits)

    # Now we know the number of bits, draw fresh 32 byte randomness per bit j:
    # r (blinds l_j), a (masks l_j in f_j), s (blinds a_j),
    # t (blinds l_j * a_j) and rho (blinds the d_k term).
    r, a, s, t, rho = [], [], [], [], []
    for _ in range(bits):
        for q in (r, a, s, t, rho):
            q.append(gen_rand())
    a_scalars = [Scalar.from_bytes(aj) for aj in a]

    # Form the coefficients of p_i(x) = prod_j f_{j,i_j}(x) for each i-th
    # element of the ring, where i_j are the bits of i and:
    # f_j,1 = delta(1, l_j) * x + a_j
    # f_j,0 = delta(0, l_j) * x - a_j
    polys = []
    for i in range(len(ring)):
        ibits = bit_decomp(i, bits)
        pi_fij = [Scalar(1)]
        for j in range(bits):
            if ibits[j] == 0:
                mult = (Scalar(delta(0, lbits[j])), a_scalars[j] * Scalar(-1))
            else:
                mult = (Scalar(delta(1, lbits[j])), a_scalars[j])
            pi_fij = poly_mult_lin(pi_fij, *mult)
        # poly_mult_lin returns coefficients highest power first; reverse so
        # that polys[i][k] is the coefficient of x^k.
        polys.append(pi_fij[::-1])

    # Commitment step of the sigma protocol.
    comm_lj = []
    comm_aj = []
    comm_ljaj = []
    comm_dk = []
    for j in range(bits):
        comm_lj.append(pedersen_commit(lbits[j], r[j]))
        comm_aj.append(pedersen_commit(a[j], s[j]))
        # Commitment to the product l_j * a_j, which is either 0 or a_j.
        if lbits[j] == 0:
            comm_ljaj.append(pedersen_commit(0, t[j]))
        else:
            comm_ljaj.append(pedersen_commit(a[j], t[j]))
        # Finally the comm_dk term ('k' is a curiosity of the paper; we can
        # just use index j): sum_i p_{i,j} * c_i + Com(0; rho_j), where
        # p_{i,j} is the x^j coefficient of p_i.
        comm_rho = pedersen_commit(0, rho[j])
        comms_poly = [pointmult(poly[j].to_bytes(), pub) for poly, pub in zip(polys, ring)]
        comm_dk.append(pointadd(comms_poly + [comm_rho]))

    # Fiat-Shamir-ised second step, the challenge: a hash of all commitments,
    # the ring and the message. verify_groth14 must rebuild this byte for byte.
    transcript = b",".join([str(item).encode() for item in [
        comm_lj, comm_aj, comm_ljaj, comm_dk, ring]]) + b"," + message
    x = Scalar.from_bytes(hash_transcript(transcript))
    print("Got challenge value: ", x)

    # Third step of the sigma protocol: the responses.
    f = []
    za = []
    zb = []
    zd = []
    for j in range(bits):
        r_j = Scalar.from_bytes(r[j])
        # f_j = l_j * x + a_j
        newf = x * lbits[j] + a_scalars[j]
        f.append(newf)
        # za_j = r_j * x + s_j
        za.append(r_j * x + Scalar.from_bytes(s[j]))
        # zb_j = r_j * (x - f_j) + t_j
        zb.append(r_j * (x - newf) + Scalar.from_bytes(t[j]))
        # rho_j * x^j, subtracted in zdsum below
        zd.append(Scalar.from_bytes(rho[j]) * Scalar.pow(x.x, j))
    # zd = sk * x^bits - sum_j rho_j * x^j
    zdsum = Scalar.from_bytes(seckey) * Scalar.pow(x.x, bits) - sum(zd)
    # The signature consists of:
    # * commitments to l, a, l*a and dk for every bit: 4 * bits points
    # * responses f, za, zb for every bit: 3 * bits scalars, plus zdsum.
    # The ring is also returned, in the shuffled order the verifier needs.
    print("Signing is complete.\n")
    return ((comm_lj + comm_aj + comm_ljaj + comm_dk + f + za + zb + [zdsum]), ring)
def decompose_signature(sig, bits):
    """ Split a flat signature into its eight components.

    ``sig`` is the flat sequence produced by ring_sign_groth14: seven groups
    of ``bits`` items (comm_lj, comm_aj, comm_ljaj, comm_dk, f, za, zb)
    followed by the single scalar zdsum.

    Returns a list of the seven groups (each a list) followed by zdsum.
    The input is NOT modified (the previous version consumed the caller's
    list with pop()).
    Raises ValueError if ``sig`` does not have exactly 7 * bits + 1 items.
    """
    items = list(sig)
    expected = 7 * bits + 1
    if len(items) != expected:
        raise ValueError("signature has %d items, expected %d" % (len(items), expected))
    groups = [items[k * bits:(k + 1) * bits] for k in range(7)]
    return groups + [items[-1]]
def verify_groth14(sig, message, ring):
    """ Verify a Groth-Kohlweiss ring signature.

    ``sig`` is the flat list returned by ring_sign_groth14, ``message`` the
    signed bytes, and ``ring`` the public keys in the (shuffled) order the
    signer returned.

    Returns True if all three verification equations hold, and False
    otherwise (including for a signature with the wrong number of elements).
    The caller's ``sig`` is not modified. The reason for a failure is
    printed.
    """
    bits = get_bits_from_ring(ring)
    if len(sig) != 7 * bits + 1:
        print("malformed signature: wrong number of elements")
        return False
    # Decompose a copy so the caller's signature list is left intact.
    comm_lj, comm_aj, comm_ljaj, comm_dk, f, za, zb, zdsum = decompose_signature(list(sig), bits)
    # Recalculate the Fiat-Shamir challenge exactly as the signer did.
    transcript = b",".join([str(item).encode() for item in [
        comm_lj, comm_aj, comm_ljaj, comm_dk, ring]]) + b"," + message
    x = Scalar.from_bytes(hash_transcript(transcript))
    print("In verifying, got challenge value: ", x)
    xb = x.to_bytes()
    for j in range(bits):
        # x * C_l + C_a == Com(f_j; za_j)
        if not pointadd([pointmult(xb, comm_lj[j]), comm_aj[j]]) == pedersen_commit(f[j].to_bytes(), za[j].to_bytes()):
            print("failed first check at bit: ", j)
            return False
        # (x - f_j) * C_l + C_la == Com(0; zb_j): the H component cancels
        # only when l_j * (1 - l_j) == 0, i.e. l_j is a bit.
        if not pointadd([pointmult((x - f[j]).to_bytes(), comm_lj[j]), comm_ljaj[j]]) == pedersen_commit(0, zb[j].to_bytes()):
            print("failed second check at bit: ", j)
            return False
    # Third check: sum_i p_i(x) * c_i - sum_k x^k * D_k == Com(0; zd).
    fc1 = pointadd([pointmult((Scalar(-1) * Scalar.pow(x.x, k)).to_bytes(), comm_dk[k]) for k in range(bits)])
    fc2s = []
    for i in range(len(ring)):
        ibits = bit_decomp(i, bits)
        # p_i(x) = prod_j f_{j,i_j}, with f_{j,1} = f_j and f_{j,0} = x - f_j.
        exp = Scalar(1)
        for j in range(bits):
            if ibits[j] == 1:
                exp *= f[j]
            else:
                exp *= x - f[j]
        fc2s.append(pointmult(exp.to_bytes(), ring[i]))
    fc2 = pointadd(fc2s)
    if not pointadd([fc1, fc2]) == pedersen_commit(0, zdsum.to_bytes()):
        print("failed third check")
        return False
    return True
def hexer(x):
    """ Hex-encode one signature element with bintohex.

    Scalars are first serialized to 32 bytes; anything else (a pubkey) is
    passed to bintohex as-is.
    """
    raw = x.to_bytes() if isinstance(x, Scalar) else x
    return bintohex(raw)
def test(ringsize):
    """ End-to-end demo: sign b"hello" with a fresh random key hidden among
    ``ringsize - 1`` random decoy pubkeys. Print the signature and its size,
    then verify it.

    Returns True if the signature verified, False otherwise.
    """
    seckey = gen_rand()
    decoys = [btc.privkey_to_pubkey(CKey(gen_rand())) for _ in range(ringsize - 1)]
    message = b"hello"
    sig, ring = ring_sign_groth14(seckey, message, decoys)
    print("Here are the elements of the ring signature:")
    hex_elements = [hexer(x) for x in sig]
    print(hex_elements)
    # Two hex characters encode one byte, so halve the total hex length to
    # report bytes (previously the hex character count was printed).
    print("The signature length in bytes is : ", sum(len(h) for h in hex_elements) // 2)
    verified = verify_groth14(sig, message, ring)
    if verified:
        print("Test succeeded; the ring signature verifies as valid, for the given message and ring.")
    else:
        print("Test failed; the ring signature did not verify.")
    return verified
if __name__ == "__main__":
    # Expect exactly one argument: the ring size, a power of 2 that is >= 2.
    if len(sys.argv) != 2:
        print(help)
        sys.exit(1)
    try:
        rs = int(sys.argv[1])
        # Exact integer power-of-two test. The earlier float check
        # math.log(rs, 2).is_integer() is not exact, and its asserts vanish
        # under ``python -O``.
        if rs < 2 or rs & (rs - 1):
            raise ValueError("ring size must be a power of 2 and at least 2, got %d" % rs)
    except ValueError as e:
        print(repr(e))
        print(help)
        sys.exit(1)
    test(rs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment