Remove repeated casting to bytes from hashing calls

pull/263/head
Bogdan Opanchuk 2021-03-18 19:11:21 -07:00
parent c419705245
commit 6af41b09d9
1 changed files with 20 additions and 19 deletions

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Type, Iterable
from typing import TYPE_CHECKING, Optional, Type, Iterable, Union
from cryptography.hazmat.primitives import hashes
@ -7,7 +7,8 @@ from .curve import CURVE
from .curve_scalar import CurveScalar
from .curve_point import CurvePoint
from .keys import PublicKey, SecretKey, Signature
from .serializable import serialize_bool
from .serializable import Serializable, serialize_bool
if TYPE_CHECKING: # pragma: no cover
from .key_frag import KeyFragID
@ -23,8 +24,8 @@ class Hash:
len_dst = len(dst).to_bytes(4, byteorder='big')
self.update(len_dst + dst)
def update(self, data: bytes) -> None:
self._hash.update(data)
def update(self, data: Union[bytes, Serializable]) -> None:
self._hash.update(bytes(data))
def finalize(self) -> bytes:
return self._hash.finalize()
@ -36,17 +37,17 @@ def hash_to_polynomial_arg(precursor: CurvePoint,
kfrag_id: 'KeyFragID',
) -> CurveScalar:
digest = Hash(b"POLYNOMIAL_ARG")
digest.update(bytes(precursor))
digest.update(bytes(pubkey))
digest.update(bytes(dh_point))
digest.update(bytes(kfrag_id))
digest.update(precursor)
digest.update(pubkey)
digest.update(dh_point)
digest.update(kfrag_id)
return CurveScalar.from_digest(digest)
def hash_capsule_points(e: CurvePoint, v: CurvePoint) -> CurveScalar:
digest = Hash(b"CAPSULE_POINTS")
digest.update(bytes(e))
digest.update(bytes(v))
digest.update(e)
digest.update(v)
return CurveScalar.from_digest(digest)
@ -55,9 +56,9 @@ def hash_to_shared_secret(precursor: CurvePoint,
dh_point: CurvePoint
) -> CurveScalar:
digest = Hash(b"SHARED_SECRET")
digest.update(bytes(precursor))
digest.update(bytes(pubkey))
digest.update(bytes(dh_point))
digest.update(precursor)
digest.update(pubkey)
digest.update(dh_point)
return CurveScalar.from_digest(digest)
@ -65,7 +66,7 @@ def hash_to_shared_secret(precursor: CurvePoint,
def hash_to_cfrag_verification(points: Iterable[CurvePoint], metadata: Optional[bytes] = None) -> CurveScalar:
digest = Hash(b"CFRAG_VERIFICATION")
for point in points:
digest.update(bytes(point))
digest.update(point)
if metadata is not None:
digest.update(metadata)
return CurveScalar.from_digest(digest)
@ -79,19 +80,19 @@ def hash_to_cfrag_signature(kfrag_id: 'KeyFragID',
) -> 'SignatureDigest':
digest = SignatureDigest(b"CFRAG_SIGNATURE")
digest.update(bytes(kfrag_id))
digest.update(bytes(commitment))
digest.update(bytes(precursor))
digest.update(kfrag_id)
digest.update(commitment)
digest.update(precursor)
if maybe_delegating_pk:
digest.update(serialize_bool(True))
digest.update(bytes(maybe_delegating_pk))
digest.update(maybe_delegating_pk)
else:
digest.update(serialize_bool(False))
if maybe_receiving_pk:
digest.update(serialize_bool(True))
digest.update(bytes(maybe_receiving_pk))
digest.update(maybe_receiving_pk)
else:
digest.update(serialize_bool(False))