Add Capsule class and encrypt()/decrypt_original()

pull/263/head
Bogdan Opanchuk 2021-03-16 16:41:05 -07:00
parent f33431d92a
commit 2c28ae8bc2
6 changed files with 137 additions and 10 deletions

View File

@ -2,7 +2,9 @@ from .__about__ import (
__author__, __license__, __summary__, __title__, __version__, __copyright__, __email__, __url__
)
from .capsule import Capsule
from .keys import SecretKey, PublicKey
from .pre import encrypt, decrypt_original
__all__ = [
"__title__",
@ -15,4 +17,7 @@ __all__ = [
"__url__",
"SecretKey",
"PublicKey",
"Capsule",
"encrypt",
"decrypt_original",
]

73
umbral/capsule.py Normal file
View File

@ -0,0 +1,73 @@
from typing import Tuple
from .curve_point import CurvePoint
from .curve_scalar import CurveScalar
from .dem import kdf
from .hashing import hash_capsule_points
from .keys import PublicKey, SecretKey
from .params import PARAMETERS
from .serializable import Serializable
class Capsule(Serializable):
class NotValid(ValueError):
"""
raised if the capsule does not pass verification.
"""
def __init__(self, point_e: CurvePoint, point_v: CurvePoint, signature: CurveScalar):
self.point_e = point_e
self.point_v = point_v
self.signature = signature
@classmethod
def __take__(cls, data: bytes) -> Tuple['Capsule', bytes]:
(e, v, sig), data = cls.__take_types__(data, CurvePoint, CurvePoint, CurveScalar)
capsule = cls(e, v, sig)
if not capsule._verify():
raise cls.NotValid("Capsule verification failed.")
return capsule, data
def __bytes__(self) -> bytes:
return bytes(self.point_e) + bytes(self.point_v) + bytes(self.signature)
@classmethod
def from_public_key(cls, pk: PublicKey) -> Tuple['Capsule', CurvePoint]:
g = CurvePoint.generator()
priv_r = CurveScalar.random_nonzero()
pub_r = g * priv_r
priv_u = CurveScalar.random_nonzero()
pub_u = g * priv_u
h = hash_capsule_points(pub_r, pub_u)
s = priv_u + (priv_r * h)
shared_key = pk._point_key * (priv_r + priv_u)
return cls(point_e=pub_r, point_v=pub_u, signature=s), shared_key
def open_original(self, sk: SecretKey) -> CurvePoint:
return (self.point_e + self.point_v) * sk.secret_scalar()
def _components(self):
return (self.point_e, self.point_v, self.signature)
def _verify(self) -> bool:
g = CurvePoint.generator()
e, v, s = self._components()
h = hash_capsule_points(e, v)
return g * s == v + (e * h)
def __eq__(self, other):
return self._components() == other._components()
def __hash__(self):
return hash((self.__class__, bytes(self)))
def __str__(self):
return f"{self.__class__.__name__}:{bytes(self).hex()[:16]}"

View File

@ -39,17 +39,12 @@ class DEM:
):
self._key = kdf(key_material, self.KEY_SIZE, salt, info)
def encrypt(self, plaintext: bytes, nonce: Optional[bytes] = None) -> bytes:
if nonce is None:
nonce = os.urandom(self.NONCE_SIZE)
if len(nonce) != self.NONCE_SIZE:
raise ValueError(f"The nonce must be exactly {self.NONCE_SIZE} bytes long")
ciphertext = xchacha_encrypt(plaintext, b"", nonce, self._key)
def encrypt(self, plaintext: bytes, authenticated_data: bytes = b"") -> bytes:
nonce = os.urandom(self.NONCE_SIZE)
ciphertext = xchacha_encrypt(plaintext, authenticated_data, nonce, self._key)
return nonce + ciphertext
def decrypt(self, nonce_and_ciphertext: bytes) -> bytes:
def decrypt(self, nonce_and_ciphertext: bytes, authenticated_data: bytes = b"") -> bytes:
if len(nonce_and_ciphertext) < self.NONCE_SIZE:
raise ValueError(f"The ciphertext must include the nonce")
@ -58,4 +53,4 @@ class DEM:
ciphertext = nonce_and_ciphertext[self.NONCE_SIZE:]
# TODO: replace `nacl.exceptions.CryptoError` with our error?
return xchacha_decrypt(ciphertext, b"", nonce, self._key)
return xchacha_decrypt(ciphertext, authenticated_data, nonce, self._key)

View File

@ -23,6 +23,28 @@ class Hash:
return self._sha256.finalize()
def digest_to_scalar(digest: Hash) -> CurveScalar:
# TODO: to be replaced by the standard algroithm.
# Currently just matching what we have in RustCrypto stack.
# Can produce zeros!
hash_digest = openssl._bytes_to_bn(digest.finalize())
bignum = openssl._get_new_BN()
with backend._tmp_bn_ctx() as bn_ctx:
res = backend._lib.BN_mod(bignum, hash_digest, CURVE.order, bn_ctx)
backend.openssl_assert(res == 1)
return CurveScalar(bignum)
def hash_capsule_points(e: CurvePoint, v: CurvePoint) -> CurveScalar:
digest = Hash(b"CAPSULE_POINTS")
digest.update(bytes(e))
digest.update(bytes(v))
return digest_to_scalar(digest)
def unsafe_hash_to_point(dst: bytes, data: bytes) -> 'Point':
"""
Hashes arbitrary data into a valid EC point of the specified curve,

View File

@ -34,6 +34,9 @@ class SecretKey(Serializable):
def __hash__(self):
raise NotImplementedError("Hashing secret objects is insecure")
def secret_scalar(self):
return self._scalar_key
@classmethod
def __take__(cls, data: bytes) -> Tuple['SecretKey', bytes]:
(scalar_key,), data = cls.__take_types__(data, CurveScalar)

29
umbral/pre.py Normal file
View File

@ -0,0 +1,29 @@
from typing import Tuple
from .capsule import Capsule
from .dem import DEM
from .keys import PublicKey, SecretKey
def encrypt(pk: PublicKey, plaintext: bytes) -> Tuple[Capsule, bytes]:
"""
Performs an encryption using the UmbralDEM object and encapsulates a key
for the sender using the public key provided.
Returns the KEM Capsule and the ciphertext.
"""
capsule, key_seed = Capsule.from_public_key(pk)
dem = DEM(bytes(key_seed))
ciphertext = dem.encrypt(plaintext, authenticated_data=bytes(capsule))
return capsule, ciphertext
def decrypt_original(sk: SecretKey, capsule: Capsule, ciphertext: bytes) -> bytes:
"""
Opens the capsule using the original (Alice's) key used for encryption and gets what's inside.
We hope that's a symmetric key, which we use to decrypt the ciphertext
and return the resulting cleartext.
"""
key_seed = capsule.open_original(sk)
dem = DEM(bytes(key_seed))
return dem.decrypt(ciphertext, authenticated_data=bytes(capsule))