Add reencryption functionality

pull/263/head
Bogdan Opanchuk 2021-03-17 21:34:12 -07:00
parent d6626ba1a6
commit b96888cafb
6 changed files with 306 additions and 7 deletions

View File

@ -3,9 +3,10 @@ from .__about__ import (
)
from .capsule import Capsule
from .capsule_frag import CapsuleFrag
from .key_frag import KeyFrag, generate_kfrags
from .keys import SecretKey, PublicKey
from .pre import encrypt, decrypt_original
from .pre import encrypt, decrypt_original, decrypt_reencrypted, reencrypt
__all__ = [
"__title__",
@ -20,7 +21,10 @@ __all__ = [
"PublicKey",
"Capsule",
"KeyFrag",
"CapsuleFrag",
"encrypt",
"decrypt_original",
"generate_kfrags",
"reencrypt",
"decrypt_reencrypted",
]

View File

@ -1,12 +1,22 @@
from typing import Tuple
from typing import TYPE_CHECKING, Tuple, Sequence
from .curve_point import CurvePoint
from .curve_scalar import CurveScalar
from .dem import kdf
from .hashing import hash_capsule_points
from .hashing import hash_capsule_points, hash_to_polynomial_arg, hash_to_shared_secret
from .keys import PublicKey, SecretKey
from .params import PARAMETERS
from .serializable import Serializable
if TYPE_CHECKING: # pragma: no cover
from .capsule_frag import CapsuleFrag
def lambda_coeff(xs: Sequence[CurveScalar], i: int) -> CurveScalar:
res = CurveScalar.one()
for j in range(len(xs)):
if j != i:
inv_diff = (xs[j] - xs[i]).invert()
res = (res * xs[j]) * inv_diff
return res
class Capsule(Serializable):
@ -54,6 +64,54 @@ class Capsule(Serializable):
def open_original(self, sk: SecretKey) -> CurvePoint:
return (self.point_e + self.point_v) * sk.secret_scalar()
def open_reencrypted(self,
receiving_sk: SecretKey,
delegating_pk: PublicKey,
cfrags: Sequence['CapsuleFrag'],
) -> CurvePoint:
if len(cfrags) == 0:
raise ValueError("Empty CapsuleFrag sequence")
precursor = cfrags[0].precursor
if len(set(cfrags)) != len(cfrags):
raise ValueError("Some of the CapsuleFrags are repeated")
if not all(cfrag.precursor == precursor for cfrag in cfrags[1:]):
raise ValueError("CapsuleFrags are not pairwise consistent")
pub_key = PublicKey.from_secret_key(receiving_sk).point()
dh_point = precursor * receiving_sk.secret_scalar()
# Combination of CFrags via Shamir's Secret Sharing reconstruction
lc = [hash_to_polynomial_arg(precursor, pub_key, dh_point, cfrag.kfrag_id)
for cfrag in cfrags]
e_primes = []
v_primes = []
for i, cfrag in enumerate(cfrags):
lambda_i = lambda_coeff(lc, i)
e_primes.append(cfrag.point_e1 * lambda_i)
v_primes.append(cfrag.point_v1 * lambda_i)
e_prime = sum(e_primes[1:], e_primes[0])
v_prime = sum(v_primes[1:], v_primes[0])
# Secret value 'd' allows to make Umbral non-interactive
d = hash_to_shared_secret(precursor, pub_key, dh_point)
s = self.signature
h = hash_capsule_points(self.point_e, self.point_v)
orig_pub_key = delegating_pk.point()
# TODO: check for d == 0? Or just let if fail?
inv_d = d.invert()
if orig_pub_key * (s * inv_d) != (e_prime * h) + v_prime:
raise ValueError("Internal validation failed")
return (e_prime + v_prime) * d
def _components(self):
return (self.point_e, self.point_v, self.signature)

200
umbral/capsule_frag.py Normal file
View File

@ -0,0 +1,200 @@
from typing import Sequence, Optional
from .capsule import Capsule
from .curve_point import CurvePoint
from .curve_scalar import CurveScalar
from .hashing import Hash, hash_to_cfrag_verification, hash_to_cfrag_signature
from .keys import PublicKey, SecretKey, Signature
from .key_frag import KeyFrag, KeyFragID
from .params import PARAMETERS
from .serializable import Serializable
class CapsuleFragProof(Serializable):
def __init__(self,
point_e2: CurvePoint,
point_v2: CurvePoint,
kfrag_commitment: CurvePoint,
kfrag_pok: CurvePoint,
signature: CurveScalar,
kfrag_signature: Signature,
):
self.point_e2 = point_e2
self.point_v2 = point_v2
self.kfrag_commitment = kfrag_commitment
self.kfrag_pok = kfrag_pok
self.signature = signature
self.kfrag_signature = kfrag_signature
def _components(self):
return (self.point_e2, self.point_v2, self.kfrag_commitment,
self.kfrag_pok, self.signature, self.kfrag_signature)
def __eq__(self, other):
return self._components() == other._components()
@classmethod
def __take__(cls, data):
types = [CurvePoint, CurvePoint, CurvePoint, CurvePoint, CurveScalar, Signature]
components, data = cls.__take_types__(data, *types)
return cls(*components), data
def __bytes__(self):
return (bytes(self.point_e2) +
bytes(self.point_v2) +
bytes(self.kfrag_commitment) +
bytes(self.kfrag_pok) +
bytes(self.signature) +
bytes(self.kfrag_signature)
)
@classmethod
def from_kfrag_and_cfrag(cls,
capsule: Capsule,
kfrag: KeyFrag,
cfrag_e1: CurvePoint,
cfrag_v1: CurvePoint,
metadata: Optional[bytes],
) -> 'CapsuleFragProof':
params = PARAMETERS
rk = kfrag.key
t = CurveScalar.random_nonzero()
# Here are the formulaic constituents shared with `CapsuleFrag.verify()`.
e = capsule.point_e
v = capsule.point_v
e1 = cfrag_e1
v1 = cfrag_v1
u = params.u
u1 = kfrag.proof.commitment
e2 = e * t
v2 = v * t
u2 = u * t
h = hash_to_cfrag_verification([e, e1, e2, v, v1, v2, u, u1, u2], metadata)
###
z3 = t + rk * h
return cls(point_e2=e2,
point_v2=v2,
kfrag_commitment=u1,
kfrag_pok=u2,
signature=z3,
kfrag_signature=kfrag.proof.signature_for_receiver,
)
class CapsuleFrag(Serializable):
def __init__(self,
point_e1: CurvePoint,
point_v1: CurvePoint,
kfrag_id: KeyFragID,
precursor: CurvePoint,
proof: CapsuleFragProof,
):
self.point_e1 = point_e1
self.point_v1 = point_v1
self.kfrag_id = kfrag_id
self.precursor = precursor
self.proof = proof
def _components(self):
return (self.point_e1, self.point_v1, self.kfrag_id, self.precursor, self.proof)
def __eq__(self, other):
return self._components() == other._components()
def __hash__(self):
return hash((self.__class__, bytes(self)))
def __str__(self):
return f"{self.__class__.__name__}:{bytes(self).hex()[:16]}"
@classmethod
def __take__(cls, data):
types = CurvePoint, CurvePoint, KeyFragID, CurvePoint, CapsuleFragProof
components, data = cls.__take_types__(data, *types)
return cls(*components), data
def __bytes__(self):
return (bytes(self.point_e1) +
bytes(self.point_v1) +
bytes(self.kfrag_id) +
bytes(self.precursor) +
bytes(self.proof))
@classmethod
def reencrypted(cls,
capsule: Capsule,
kfrag: KeyFrag,
metadata: Optional[bytes] = None,
) -> 'CapsuleFrag':
rk = kfrag.key
e1 = capsule.point_e * rk
v1 = capsule.point_v * rk
proof = CapsuleFragProof.from_kfrag_and_cfrag(capsule, kfrag, e1, v1, metadata)
return cls(point_e1=e1,
point_v1=v1,
kfrag_id=kfrag.id,
precursor=kfrag.precursor,
proof=proof,
)
def verify(self,
capsule: Capsule,
delegating_pk: PublicKey,
receiving_pk: PublicKey,
signing_pk: PublicKey,
metadata: Optional[bytes] = None,
) -> bool:
params = PARAMETERS
# Here are the formulaic constituents shared with
# `CapsuleFragProof.from_kfrag_and_cfrag`.
e = capsule.point_e
v = capsule.point_v
e1 = self.point_e1
v1 = self.point_v1
u = params.u
u1 = self.proof.kfrag_commitment
e2 = self.proof.point_e2
v2 = self.proof.point_v2
u2 = self.proof.kfrag_pok
h = hash_to_cfrag_verification([e, e1, e2, v, v1, v2, u, u1, u2], metadata)
###
precursor = self.precursor
kfrag_id = self.kfrag_id
kfrag_signature = hash_to_cfrag_signature(kfrag_id, u1, precursor, delegating_pk, receiving_pk)
valid_kfrag_signature = kfrag_signature.verify(signing_pk, self.proof.kfrag_signature)
z3 = self.proof.signature
correct_reencryption_of_e = e * z3 == e2 + e1 * h
correct_reencryption_of_v = v * z3 == v2 + v1 * h
correct_rk_commitment = u * z3 == u2 + u1 * h
return (valid_kfrag_signature
and correct_reencryption_of_e
and correct_reencryption_of_v
and correct_rk_commitment)

View File

@ -27,7 +27,7 @@ class CurveScalar(Serializable):
"""
Returns a CurveScalar object with a cryptographically secure OpenSSL BIGNUM.
"""
one = backend._lib.BN_value_one()
one = cls.one()._backend_bignum
# TODO: in most cases, we want this number to be secret.
# OpenSSL 1.1.1 has `BN_priv_rand_range()`, but it is not
@ -89,6 +89,10 @@ class CurveScalar(Serializable):
# -1 less than, 0 is equal to, 1 is greater than
return not bool(backend._lib.BN_cmp(self._backend_bignum, other._backend_bignum))
@classmethod
def one(cls):
return cls(backend._lib.BN_value_one())
def is_zero(self):
# BN_is_zero() is not exported, so this will have to do
return self == 0

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Type
from typing import TYPE_CHECKING, Optional, Type, Iterable
from cryptography.hazmat.backends.openssl import backend
from cryptography.hazmat.primitives import hashes
@ -15,6 +15,9 @@ if TYPE_CHECKING: # pragma: no cover
class Hash:
OUTPUT_SIZE = 32
def __init__(self, dst: bytes):
self._sha256 = hashes.Hash(hashes.SHA256(), backend=backend)
len_dst = len(dst).to_bytes(4, byteorder='big')
@ -73,6 +76,16 @@ def hash_to_shared_secret(precursor: CurvePoint,
return digest_to_scalar(digest)
def hash_to_cfrag_verification(points: Iterable[CurvePoint], metadata: Optional[bytes] = None) -> CurveScalar:
digest = Hash(b"CFRAG_VERIFICATION")
for point in points:
digest.update(bytes(point))
if metadata is not None:
digest.update(metadata)
return digest_to_scalar(digest)
def hash_to_cfrag_signature(kfrag_id: 'KeyFragID',
commitment: CurvePoint,
precursor: CurvePoint,

View File

@ -1,8 +1,10 @@
from typing import Tuple
from typing import Tuple, Optional, Sequence
from .capsule import Capsule
from .capsule_frag import CapsuleFrag
from .dem import DEM
from .keys import PublicKey, SecretKey
from .key_frag import KeyFrag
def encrypt(pk: PublicKey, plaintext: bytes) -> Tuple[Capsule, bytes]:
@ -27,3 +29,21 @@ def decrypt_original(sk: SecretKey, capsule: Capsule, ciphertext: bytes) -> byte
key_seed = capsule.open_original(sk)
dem = DEM(bytes(key_seed))
return dem.decrypt(ciphertext, authenticated_data=bytes(capsule))
def reencrypt(capsule: Capsule, kfrag: KeyFrag, metadata: Optional[bytes] = None) -> CapsuleFrag:
return CapsuleFrag.reencrypted(capsule, kfrag, metadata)
def decrypt_reencrypted(decrypting_sk: SecretKey,
delegating_pk: PublicKey,
capsule: Capsule,
cfrags: Sequence[CapsuleFrag],
ciphertext: bytes,
) -> bytes:
key_seed = capsule.open_reencrypted(decrypting_sk, delegating_pk, cfrags)
# TODO: add salt and info here?
dem = DEM(bytes(key_seed))
return dem.decrypt(ciphertext, authenticated_data=bytes(capsule))