diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..cfd050c --- /dev/null +++ b/Makefile @@ -0,0 +1,25 @@ +PYTHON ?= PYTHONPATH=. python3 +SOLC ?= ./solc-static-linux +SOLC_VERSION = 0.5.8 +SOLC_OPTS ?= --optimize + +all: + @echo "Read Makefile?" + +clean: + rm -rf build .coverage .coverage.* + find . -name '*.pyc' -exec rm '{}' ';' + find . -name '__pycache__' -exec rm '{}' ';' + +.PHONY: test +test: + cd contracts && PYTHONPATH=.. python3 -m unittest discover ../tests/ + +requirements: + $(PYTHON) -mpip install -r requirements.txt + + +# Retrieve static built solidity compiler for Linux (useful...) +solc-static-linux: + curl -L -o $@ "https://github.com/ethereum/solidity/releases/download/v$(SOLC_VERSION)/solc-static-linux" || rm -f $@ + chmod 755 $@ diff --git a/bls.py b/bls.py new file mode 100644 index 0000000..cb962df --- /dev/null +++ b/bls.py @@ -0,0 +1,125 @@ +from functools import reduce +import binascii +from os import urandom +from py_ecc.bn128 import * +from sha3 import keccak_256 + +""" +Implements BLS signatture aggregation as described at: + + https://crypto.stanford.edu/~dabo/pubs/papers/BLSmultisig.html + +--------- + +Roughly based on the following code: + + >>> from py_ecc.bn128 import * + >>> from random import randint + + >>> sk1, sk2 = randint(1,curve_order-1), randint(1, curve_ + ... order-1) + + >>> pk1, pk2 = multiply(G2, sk1), multiply(G2, sk2) + + >>> H = multiply(G1, randint(1, curve_order-1)) + + >>> sig1, sig2 = multiply(H, sk1), multiply(H, sk2) + + >>> aggpk = add(pk1, pk2) + + >>> aggsig = add(sig1, sig2) + + >>> pairing(aggpk, H) == pairing(G2, aggsig) + True + + +""" + +addmodp = lambda x, y: (x + y) % field_modulus + +mulmodp = lambda x, y: (x * y) % field_modulus + +safe_ord = lambda x: x if isinstance(x, int) else ord(x) + +bytes_to_int = lambda x: reduce(lambda o, b: (o << 8) + safe_ord(b), [0] + list(x)) + +def int_to_big_endian(lnum): + if lnum == 0: + return b'\0' + s = hex(lnum)[2:].rstrip('L') + if len(s) & 1: + s = '0' + s + return binascii.unhexlify(s) + +zpad = lambda x, l: b'\x00' * max(0, l - len(x)) + x + +tobe256 = lambda v: zpad(int_to_big_endian(v), 32) + +def g2_to_list(point): + return [_.n for _ in point[0].coeffs + point[1].coeffs] + +def g1_to_list(point): + return [_.n for _ in point] + +fmt_list = lambda point: '[' + ', '.join([('"' + hex(_) + '"') for _ in point]) + ']' + +def randn(): + return int.from_bytes(urandom(64), 'big') % curve_order + +def hashs(*x): + data = b''.join(map(tobe256, x)) + return bytes_to_int(keccak_256(data).digest()) + +def evalcurve_g1(x): + beta = addmodp(mulmodp(mulmodp(x, x), x), 3) + assert field_modulus % 4 == 3 + a = (field_modulus+1)//4 # fast square root, using exponentation, assuming (p%4 == 3) + y = pow(beta, a, field_modulus) + return (beta, y) + +def isoncurve_g1(x, y): + return mulmodp(y, y) == addmodp(mulmodp(mulmodp(x, x), x), 3) + +def hash_to_g1(x): + # XXX: todo, re-hash on every round + assert isinstance(x, int) + x = x % field_modulus + while True: + beta, y = evalcurve_g1(x) + if beta == mulmodp(y, y): + assert isoncurve_g1(x, y) + return FQ(x), FQ(y) + x = addmodp(x, 1) + +def hash_g2(x): + xy_ints = [_.n for _ in (x[0].coeffs + x[1].coeffs)] + return hashs(*xy_ints) + +def bls_keygen(): + sk = randn() + pk = multiply(G2, sk) + return sk, pk + +def bls_prove_key(sk): + pk = multiply(G2, sk) + msg = hash_g2(pk) + return bls_sign(sk, msg) + +def bls_verify_key(pk, sig): + msg = hash_g2(pk) + return bls_verify(pk, msg, sig) + +def bls_sign(sk, msg): + H_m = hash_to_g1(msg) + sig = multiply(H_m, sk) + return sig + +def bls_verify(pk, msg, sig) -> bool: + H_m = hash_to_g1(msg) + return pairing(pk, H_m) * pairing(neg(G2), sig) == FQ12.one() + +def bls_agg_verify(pks, msg, sigs) -> bool: + H_m = hash_to_g1(msg) + agg_pk = reduce(add, pks) + agg_sig = reduce(add, sigs) + return pairing(agg_pk, H_m) * pairing(neg(G2), agg_sig) == FQ12.one() diff --git a/contracts/BLSValidators.sol b/contracts/BLSValidators.sol index 38bd130..ff5755b 100644 --- a/contracts/BLSValidators.sol +++ b/contracts/BLSValidators.sol @@ -1,223 +1,215 @@ -pragma solidity ^0.4.24; +pragma solidity ^0.5.8; -import { BN256G2 } from "./BN256G2.sol"; - -/* -Toy working POC on BLS Sig and aggregation in Ethereum. - -Signatures are generated using https://github.com/0xAshish/py_ecc/blob/master/tests/BLSsmall.py -Code is based on https://github.com/jstoxrocky/zksnarks_example - -*/ +import { BN256G2 } from "./BN256G2.sol"; -contract BLSValidators { - struct G1Point { - uint X; - uint Y; - } +contract BLSValidators +{ + uint256 internal constant FIELD_ORDER = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47; - // Encoding of field elements is: X[0] * z + X[1] - struct G2Point { - uint[2] X; - uint[2] Y; - } + // a = (FIELD_ORDER+1) // 4 + uint256 internal constant CURVE_A = 0xc19139cb84c680a6e14116da060561765e05aa45a1c72a34f082305b61f3f52; struct Validator { - address user; + address owner; uint256 amount; - G1Point pubkey; + uint256[4] pubkey; } - uint256 public vCount = 0; - mapping (uint256 => Validator) public validators; + uint256 internal aggregate_bitmask; + uint256[4] internal aggregate_pubkey; - event newValidator(uint256 indexed validatorId); + mapping (uint8 => Validator) internal validators; - function addValidator(uint256 pkX, uint256 pkY, uint256 amount) public { - vCount++; - validators[vCount] = Validator(msg.sender, amount, G1Point(pkX, pkY)); - emit newValidator(vCount); - } + event OnNewValidator(uint8 index, address owner, uint256[4] pk); - function addValidatorTest(uint256 amount, uint256 _pk,uint256 n) public { - for(uint256 i = 0 ;i < n; i++) { - vCount++; - // Temporary - G1Point memory pk = mul(P1(), _pk+i); - validators[vCount] = Validator(msg.sender,amount, pk); - } - } + event OnValidatorRemoved(uint8 index); - function getValidatorDetails(uint256 id) public view - returns( - address, - uint256, - uint256, - uint256 - ) { - return (validators[id].user, validators[id].amount, validators[id].pubkey.X, validators[id].pubkey.Y); + constructor () public { + aggregate_pubkey = [uint256(0), uint256(0), uint256(0), uint256(0)]; } - function checkSigAGG(uint256 bitmask, uint256 sigs0, uint256 sigs1, uint256 sigs2, uint256 sigs3, uint256 message) public returns(bool) { - G1Point memory pubkey; - for(uint256 i = 0; i < vCount; i++) { - // if((bitmask >> i) & 1 > 0) { - Validator v = validators[i+1]; - pubkey = add(pubkey, v.pubkey); - // } + function HashToG1(uint256 s) + internal view returns (uint256[2] memory) + { + uint256 beta = 0; + uint256 y = 0; + uint256 x = s % FIELD_ORDER; + while( true ) { + (beta, y) = FindYforX(x); + if(beta == mulmod(y, y, FIELD_ORDER)) { + return [x, y]; + } + x = addmod(x, 1, FIELD_ORDER); } - - G2Point memory H = hashToG2(message); - G2Point memory signature = G2Point([sigs1,sigs0],[sigs3,sigs2]); - return pairing2(P1(), H, negate(pubkey), signature); } - - function testCheckSigAGG() public { - - G1Point memory pubkey = G1Point( - 17380323886581056473092238415087178747833394266216426706118377188344506669132, - 8264330258127714892906603723635360533223500611780692134587255146148491007336); - // hash on G2 point - G2Point memory H = G2Point( - [7806540115951598708068323537226325143489341620121102987168061034219723055482, - 16102053849180588443131133900438094849149715436625045469236991987039241848240], - [6718946360417026759307173704450430250787528919693688413464546568151449945362, - 15085587210032391178752839157819905008772577581989468040951987143794090031385]); - - G2Point memory signature = G2Point( - [20510297253563043906240734487189027213933976667621835319448331165769997484335, - 17039283792713629953217756598150981109636679343767085841835508695942368202923], - [1985362097212581787757922254110217851026070065076532109495179805548055991837, - 7135647869386222135872517926452623520408611489591663660104271578165118400268]); - - require(pairing2(P1(), H, negate(pubkey), signature), "Something went wrong"); + function FindYforX(uint256 x) + internal view returns (uint256, uint256) + { + // beta = (x^3 + b) % p + uint256 beta = addmod(mulmod(mulmod(x, x, FIELD_ORDER), x, FIELD_ORDER), 3, FIELD_ORDER); + uint256 y = modPow(beta, CURVE_A, FIELD_ORDER); + return (beta, y); + } + + function AddValidator(uint8 index, uint256[4] memory pk, uint256[2] memory sig) + public payable + { + require( msg.value != 0 ); + require( validators[index].owner == address(0) ); + require( ProvePublicKey(pk, sig) ); + + validators[index] = Validator(msg.sender, msg.value, pk); + + // To handle the special case where all validators agree on something + // We pre-accumulate the keys to avoid doing it every time a signature is validated + // Maintain a bitmask of their indices + uint256[4] memory p; + (p[0], p[1], p[2], p[3]) = BN256G2.ECTwistAdd(aggregate_pubkey[0], aggregate_pubkey[1], + aggregate_pubkey[2], aggregate_pubkey[3], + pk[0], pk[1], pk[2], pk[3]); + aggregate_pubkey = p; + aggregate_bitmask = aggregate_bitmask & (uint256(1)<> i) & 1 > 0 ) + { + require( validators[i].owner != address(0) ); + uint256[4] memory p = validators[i].pubkey; + (ap[0], ap[1], ap[2], ap[3]) = BN256G2.ECTwistAdd(ap[0], ap[1], ap[2], ap[3], + p[0], p[1], p[2], p[3]); + } + } + } + return CheckSignature(message, ap, sig); } - /// @return the generator of G1 - function P1() internal returns (G1Point) { - return G1Point(1, 2); + /// @return the generator of G2 + function G2() + internal pure returns (uint256[4] memory) + { + return [0x1800deef121f1e76426a00665e5c4479674322d4f75edadd46debd5cd992f6ed, + 0x198e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c2, + 0x12c85ea5db8c6deb4aab71808dcb408fe3d1e7690c43d37b4ce6cc0166fa7daa, + 0x90689d0585ff075ec9e99ad690c3395bc4b313370b38ef355acdadcd122975b]; } - /// @return the generator of G2 - function P2() internal returns (G2Point) { - return G2Point( - [11559732032986387107991004021392285783925812861821192530917403151452391805634, - 10857046999023057135944570762232829481370756359578518086990519993285655852781], - - [4082367875863433681332203403145435568316851327593401208105741076214120093531, - 8495653923123431417604973247489272438418190587263600148770280649306958101930] - ); - } - - /// @return the result of computing the pairing check - /// e(p1[0], p2[0]) * .... * e(p1[n], p2[n]) == 1 - /// For example pairing([P1(), P1().negate()], [P2(), P2()]) should - /// return true. - function pairing(G1Point[] memory p1, G2Point[] memory p2) internal returns (bool) { - require(p1.length == p2.length); - uint elements = p1.length; - uint inputSize = elements * 6; - uint[] memory input = new uint[](inputSize); - - for (uint i = 0; i < elements; i++) - { - input[i * 6 + 0] = p1[i].X; - input[i * 6 + 1] = p1[i].Y; - input[i * 6 + 2] = p2[i].X[0]; - input[i * 6 + 3] = p2[i].X[1]; - input[i * 6 + 4] = p2[i].Y[0]; - input[i * 6 + 5] = p2[i].Y[1]; - } + function NegativeG2() + internal pure returns (uint256[4] memory) + { + return [0x1800deef121f1e76426a00665e5c4479674322d4f75edadd46debd5cd992f6ed, + 0x198e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c2, + 0x1d9befcd05a5323e6da4d435f3b617cdb3af83285c2df711ef39c01571827f9d, + 0x275dc4a288d1afb3cbb1ac09187524c7db36395df7be3b99e673b13a075a65ec]; + } + /// Convenience method for a pairing check for two pairs. + function pairing2(uint256[2] memory a1, uint256[4] memory a2, uint256[2] memory b1, uint256[4] memory b2) + internal view returns (bool) + { + uint256[12] memory input = [ + a1[0], a1[1], // a1 (G1) + a2[1], a2[0], a2[3], a2[2], // a2 (G2) + + b1[0], b1[1], // b1 (G1) + b2[1], b2[0], b2[3], b2[2] // b2 (G2) + ]; uint[1] memory out; - bool success; - assembly { - success := call(sub(gas, 2000), 8, 0, add(input, 0x20), mul(inputSize, 0x20), out, 0x20) - // Use "invalid" to make gas estimation work - switch success case 0 {invalid} + if iszero(staticcall(sub(gas, 2000), 8, input, 0x180, out, 0x20)) { + revert(0, 0) + } } - require(success); return out[0] != 0; } - /// Convenience method for a pairing check for two pairs. - function pairing2(G1Point a1, G2Point a2, G1Point b1, G2Point b2) internal returns (bool) { - G1Point[] memory p1 = new G1Point[](2); - G2Point[] memory p2 = new G2Point[](2); - p1[0] = a1; - p1[1] = b1; - p2[0] = a2; - p2[1] = b2; - return pairing(p1, p2); - } - - function hashToG1(uint256 h) internal returns (G1Point) { - return mul(P1(), h); - } - - function hashToG2(uint256 h) internal returns (G2Point memory) { - G2Point memory p2 = P2(); - uint256 x1; - uint256 x2; - uint256 y1; - uint256 y2; - (x1,x2,y1,y2) = BN256G2.ECTwistMul(h, p2.X[1], p2.X[0], p2.Y[1], p2.Y[0]); - return G2Point([x2,x1],[y2,y1]); - } - - function modPow(uint256 base, uint256 exponent, uint256 modulus) internal returns (uint256) { + function modPow(uint256 base, uint256 exponent, uint256 modulus) + internal view returns (uint256) + { uint256[6] memory input = [32, 32, 32, base, exponent, modulus]; uint256[1] memory result; assembly { - if iszero(call(not(0), 0x05, 0, input, 0xc0, result, 0x20)) { + if iszero(staticcall(not(0), 0x05, input, 0xc0, result, 0x20)) { revert(0, 0) } } return result[0]; } - /// @return the negation of p, i.e. p.add(p.negate()) should be zero. - function negate(G1Point p) internal returns (G1Point) { - // The prime q in the base field F_q for G1 - uint q = 21888242871839275222246405745257275088696311157297823662689037894645226208583; - if (p.X == 0 && p.Y == 0) - return G1Point(0, 0); - return G1Point(p.X, q - (p.Y % q)); - } - - /// @return the sum of two points of G1 - function add(G1Point p1, G1Point p2) internal returns (G1Point r) { - uint[4] memory input; - input[0] = p1.X; - input[1] = p1.Y; - input[2] = p2.X; - input[3] = p2.Y; - bool success; - assembly { - success := call(sub(gas, 2000), 6, 0, input, 0xc0, r, 0x60) - // Use "invalid" to make gas estimation work - switch success case 0 {invalid} - } - require(success); + function negate(uint256 value) + internal pure returns (uint256) + { + return FIELD_ORDER - (value % FIELD_ORDER); } - /// @return the product of a point on G1 and a scalar, i.e. - /// p == p.mul(1) and p.add(p) == p.mul(2) for all points p. - function mul(G1Point p, uint s) internal returns (G1Point r) { - uint[3] memory input; - input[0] = p.X; - input[1] = p.Y; - input[2] = s; - bool success; - assembly { - success := call(sub(gas, 2000), 7, 0, input, 0x80, r, 0x60) - // Use "invalid" to make gas estimation work - switch success case 0 {invalid} - } - require(success); + function negate(uint256[4] memory p) + internal pure returns (uint256[4] memory) + { + return [p[0], p[1], negate(p[2]), negate(p[3])]; } } diff --git a/contracts/BN256G2.sol b/contracts/BN256G2.sol index e74db01..d49f7e9 100644 --- a/contracts/BN256G2.sol +++ b/contracts/BN256G2.sol @@ -1,4 +1,4 @@ -pragma solidity ^0.4.24; +pragma solidity ^0.5.8; /** * @title Elliptic curve operations on twist points for alt_bn128 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7c86a67 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +ethereum +py_ecc +pysha3 diff --git a/tests/test_BLSValidators.py b/tests/test_BLSValidators.py new file mode 100644 index 0000000..0a72540 --- /dev/null +++ b/tests/test_BLSValidators.py @@ -0,0 +1,32 @@ +import unittest +from ethereum.tools import tester +from binascii import hexlify + +import bls + + +class TestBLSValidators(unittest.TestCase): + def setUp(self): + env = tester.get_env(None) + env.config['BLOCK_GAS_LIMIT'] = 5**10 + chain = tester.Chain(env=env) + + with open('BN256G2.sol') as handle: + source = handle.read() + BN256G2 = chain.contract(source, language='solidity') + with open('BLSValidators.sol') as handle: + source = handle.read() + self.contract = chain.contract(source, libraries={'BN256G2': hexlify(BN256G2.address)}, language='solidity') + + def test_AddValidator(self): + index = 0 + sk, pk = bls.bls_keygen() + sig = bls.bls_prove_key(sk) + result = self.contract.AddValidator(0, bls.g2_to_list(pk), bls.g1_to_list(sig), value=1000) + print(result) + result2 = self.contract.GetValidator(0) + print(result2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_bls.py b/tests/test_bls.py new file mode 100644 index 0000000..23aa7a5 --- /dev/null +++ b/tests/test_bls.py @@ -0,0 +1,32 @@ +import unittest +import bls + + +class BlsTests(unittest.TestCase): + def test_prove_key(self): + sk, pk = bls.bls_keygen() + sig = bls.bls_prove_key(sk) + self.assertTrue(bls.bls_verify_key(pk, sig)) + + def test_aggregate(self): + "Aggregate signature where the two participants sign the same message" + sk1, pk1 = bls.bls_keygen() + sk2, pk2 = bls.bls_keygen() + msg = bls.randn() + sig1 = bls.bls_sign(sk1, msg) + sig2 = bls.bls_sign(sk2, msg) + self.assertTrue(bls.bls_agg_verify([pk1, pk2], msg, [sig1, sig2])) + + def test_aggregate_bad(self): + "Bad aggregate signature, where the two participants sign different messages" + sk1, pk1 = bls.bls_keygen() + sk2, pk2 = bls.bls_keygen() + msg1 = bls.randn() + msg2 = bls.randn() + sig1 = bls.bls_sign(sk1, msg1) + sig2 = bls.bls_sign(sk2, msg2) + self.assertFalse(bls.bls_agg_verify([pk1, pk2], msg1, [sig1, sig2])) + + +if __name__ == "__main__": + unittest.main()