pyUmbral/umbral/hashing.py

128 lines
4.2 KiB
Python

from typing import TYPE_CHECKING, Optional, Iterable, Union, List, cast

from cryptography.hazmat.primitives import hashes

from .openssl import backend, ErrorInvalidCompressedPoint
from .curve import CURVE
from .curve_scalar import CurveScalar
from .curve_point import CurvePoint
from .serializable import Serializable, bool_bytes

if TYPE_CHECKING:  # pragma: no cover
    from .key_frag import KeyFragID
    from .keys import PublicKey
class Hash:
    """
    Incremental SHA-256 hasher with an optional domain separation tag.

    If ``dst`` is given, the hash state is seeded with the tag's length
    (4 bytes, big-endian) followed by the tag bytes. This makes the tag
    boundary unambiguous before any caller data is absorbed.
    """

    OUTPUT_SIZE = 32  # SHA-256 digest length in bytes

    def __init__(self, dst: Optional[bytes] = None):
        algorithm = hashes.SHA256()
        self._backend_hash_algorithm = algorithm
        self._hash = hashes.Hash(algorithm, backend=backend)
        if dst is None:
            return
        tag_prefix = len(dst).to_bytes(4, byteorder='big')
        self.update(tag_prefix + dst)

    def update(self, data: Union[bytes, Serializable]) -> None:
        """Absorb ``data``; any value accepted by ``bytes()`` (e.g. a Serializable) works."""
        raw = bytes(data)
        self._hash.update(raw)

    def finalize(self) -> bytes:
        """Finish hashing and return the digest produced by the backend hash object."""
        result = self._hash.finalize()
        return result
def hash_to_polynomial_arg(precursor: CurvePoint,
                           pubkey: CurvePoint,
                           dh_point: CurvePoint,
                           kfrag_id: 'KeyFragID',
                           ) -> CurveScalar:
    """
    Derive a scalar under the ``POLYNOMIAL_ARG`` domain tag.

    The inputs are absorbed in this fixed order: precursor, pubkey,
    dh_point, kfrag_id.
    """
    hasher = Hash(b"POLYNOMIAL_ARG")
    for component in (precursor, pubkey, dh_point, kfrag_id):
        hasher.update(component)
    return CurveScalar.from_digest(hasher)
def hash_capsule_points(e: CurvePoint, v: CurvePoint) -> CurveScalar:
    """Derive a scalar from capsule points ``e`` then ``v`` under the ``CAPSULE_POINTS`` tag."""
    hasher = Hash(b"CAPSULE_POINTS")
    for point in (e, v):
        hasher.update(point)
    scalar = CurveScalar.from_digest(hasher)
    return scalar
def hash_to_shared_secret(precursor: CurvePoint,
                          pubkey: CurvePoint,
                          dh_point: CurvePoint
                          ) -> CurveScalar:
    """
    Derive a scalar under the ``SHARED_SECRET`` domain tag.

    The inputs are absorbed in this fixed order: precursor, pubkey, dh_point.
    """
    hasher = Hash(b"SHARED_SECRET")
    for component in (precursor, pubkey, dh_point):
        hasher.update(component)
    return CurveScalar.from_digest(hasher)
def hash_to_cfrag_verification(points: Iterable[CurvePoint]) -> CurveScalar:
    """
    Derive a scalar from a sequence of points under the ``CFRAG_VERIFICATION`` tag.

    Points are absorbed in iteration order, and ``points`` is consumed exactly once.
    """
    hasher = Hash(b"CFRAG_VERIFICATION")
    for curve_point in points:
        hasher.update(curve_point)
    scalar = CurveScalar.from_digest(hasher)
    return scalar
def kfrag_signature_message(kfrag_id: 'KeyFragID',
                            commitment: CurvePoint,
                            precursor: CurvePoint,
                            maybe_delegating_pk: Optional['PublicKey'],
                            maybe_receiving_pk: Optional['PublicKey'],
                            ) -> bytes:
    """
    Build the byte string that gets signed for a key fragment.

    Layout::

        kfrag_id || commitment || precursor || <delegating part> || <receiving part>

    Each optional-key part is ``bool_bytes(True) || bytes(key)`` when the key
    is given, or just ``bool_bytes(False)`` when it is ``None``.
    """

    def _optional_key_part(maybe_pk: Optional['PublicKey']) -> List[bytes]:
        # Explicit None check, matching the Optional annotation, instead of
        # relying on the truthiness of a key object.
        if maybe_pk is None:
            return [bool_bytes(False)]
        return [bool_bytes(True), bytes(maybe_pk)]

    # Everything is converted to ``bytes`` manually, so the list has one concrete
    # element type that mypy can check. The original's bare ``cast(...)`` call
    # discarded its result and had no effect.
    components: List[bytes] = [bytes(kfrag_id), bytes(commitment), bytes(precursor)]
    components += _optional_key_part(maybe_delegating_pk)
    components += _optional_key_part(maybe_receiving_pk)
    return b''.join(components)
def unsafe_hash_to_point(dst: bytes, data: bytes) -> CurvePoint:
    """
    Hashes arbitrary data into a valid EC point of the specified curve,
    using the try-and-increment method.

    Each attempt works the same way:
    - Hash ``len(data) || data || counter`` under the domain tag ``dst``.
      The length and the counter are each 4 bytes, big-endian.
    - Truncate the digest to ``CURVE.field_element_size`` bytes.
    - Try to decode it, with a ``0x02`` prefix, as a compressed point.

    The first candidate that decodes is returned.

    WARNING: Do not use when the input data is secret, as this implementation is not
    in constant time, and hence, it is not safe with respect to timing attacks.

    Raises:
        ValueError: if none of the 2**32 candidates decodes to a valid point.
    """
    # Length-prefix the data so that (data, counter) pairs are encoded unambiguously.
    data_with_len = len(data).to_bytes(4, byteorder='big') + data
    # Fixed compressed-point prefix; in SEC1 encoding 0x02 selects the even-y point.
    sign = b'\x02'
    field_size = CURVE.field_element_size

    # We use an internal 32-bit counter as additional input
    for i in range(2**32):
        digest = Hash(dst)
        digest.update(data_with_len + i.to_bytes(4, byteorder='big'))
        candidate = sign + digest.finalize()[:field_size]
        try:
            return CurvePoint.from_bytes(candidate)
        except ErrorInvalidCompressedPoint:
            # Not a valid compressed point; try the next counter value.
            continue

    # Each candidate fails roughly half the time, independently, so exhausting all
    # 2^32 counter values happens with probability about 2^(-2^32), i.e. never in practice.
    raise ValueError('Could not hash input into the curve')  # pragma: no cover