mirror of https://github.com/nucypher/pyUmbral.git
Extract signing code into a separate module
parent
dd76047a42
commit
c4626fa071
|
@ -3,8 +3,7 @@ import string
|
|||
|
||||
import pytest
|
||||
|
||||
from umbral.keys import PublicKey, SecretKey, SecretKeyFactory, Signature
|
||||
from umbral.hashing import Hash
|
||||
from umbral.keys import PublicKey, SecretKey, SecretKeyFactory
|
||||
|
||||
|
||||
def test_gen_key():
|
||||
|
@ -117,87 +116,3 @@ def test_public_key_is_hashable():
|
|||
|
||||
pk3 = PublicKey.from_bytes(bytes(pk))
|
||||
assert hash(pk) == hash(pk3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('execution_number', range(20)) # Run this test 20 times.
|
||||
def test_sign_and_verify(execution_number):
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
assert signature.verify_digest(pk, digest)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('execution_number', range(20)) # Run this test 20 times.
|
||||
def test_sign_serialize_and_verify(execution_number):
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
signature_bytes = bytes(signature)
|
||||
signature_restored = Signature.from_bytes(signature_bytes)
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
assert signature_restored.verify_digest(pk, digest)
|
||||
|
||||
|
||||
def test_verification_fail():
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
# wrong DST
|
||||
digest = Hash(b"other dst")
|
||||
digest.update(message)
|
||||
assert not signature.verify_digest(pk, digest)
|
||||
|
||||
# wrong message
|
||||
digest = Hash(dst)
|
||||
digest.update(b"no peace at dawn")
|
||||
assert not signature.verify_digest(pk, digest)
|
||||
|
||||
# bad signature
|
||||
signature_bytes = bytes(signature)
|
||||
signature_bytes = b'\x00' + signature_bytes[1:]
|
||||
signature_restored = Signature.from_bytes(signature_bytes)
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
assert not signature_restored.verify_digest(pk, digest)
|
||||
|
||||
|
||||
def test_signature_repr():
|
||||
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
s = repr(signature)
|
||||
assert 'Signature' in s
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
import pytest
|
||||
|
||||
from umbral.keys import PublicKey, SecretKey
|
||||
from umbral.signing import Signature
|
||||
from umbral.hashing import Hash
|
||||
|
||||
|
||||
@pytest.mark.parametrize('execution_number', range(20)) # Run this test 20 times.
|
||||
def test_sign_and_verify(execution_number):
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
assert signature.verify_digest(pk, digest)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('execution_number', range(20)) # Run this test 20 times.
|
||||
def test_sign_serialize_and_verify(execution_number):
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
signature_bytes = bytes(signature)
|
||||
signature_restored = Signature.from_bytes(signature_bytes)
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
assert signature_restored.verify_digest(pk, digest)
|
||||
|
||||
|
||||
def test_verification_fail():
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
# wrong DST
|
||||
digest = Hash(b"other dst")
|
||||
digest.update(message)
|
||||
assert not signature.verify_digest(pk, digest)
|
||||
|
||||
# wrong message
|
||||
digest = Hash(dst)
|
||||
digest.update(b"no peace at dawn")
|
||||
assert not signature.verify_digest(pk, digest)
|
||||
|
||||
# bad signature
|
||||
signature_bytes = bytes(signature)
|
||||
signature_bytes = b'\x00' + signature_bytes[1:]
|
||||
signature_restored = Signature.from_bytes(signature_bytes)
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
assert not signature_restored.verify_digest(pk, digest)
|
||||
|
||||
|
||||
def test_signature_repr():
|
||||
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
s = repr(signature)
|
||||
assert 'Signature' in s
|
|
@ -8,6 +8,7 @@ from .errors import GenericError
|
|||
from .key_frag import KeyFrag, generate_kfrags
|
||||
from .keys import SecretKey, PublicKey, SecretKeyFactory
|
||||
from .pre import encrypt, decrypt_original, decrypt_reencrypted, reencrypt
|
||||
from .signing import Signature
|
||||
|
||||
__all__ = [
|
||||
"__title__",
|
||||
|
@ -21,6 +22,7 @@ __all__ = [
|
|||
"SecretKey",
|
||||
"PublicKey",
|
||||
"SecretKeyFactory",
|
||||
"Signature",
|
||||
"Capsule",
|
||||
"KeyFrag",
|
||||
"CapsuleFrag",
|
||||
|
|
|
@ -4,10 +4,11 @@ from .capsule import Capsule
|
|||
from .curve_point import CurvePoint
|
||||
from .curve_scalar import CurveScalar
|
||||
from .hashing import Hash, hash_to_cfrag_verification, hash_to_cfrag_signature
|
||||
from .keys import PublicKey, SecretKey, Signature
|
||||
from .keys import PublicKey, SecretKey
|
||||
from .key_frag import KeyFrag, KeyFragID
|
||||
from .params import PARAMETERS
|
||||
from .serializable import Serializable
|
||||
from .signing import Signature
|
||||
|
||||
|
||||
class CapsuleFragProof(Serializable):
|
||||
|
|
|
@ -6,8 +6,9 @@ from .openssl import backend, ErrorInvalidCompressedPoint
|
|||
from .curve import CURVE
|
||||
from .curve_scalar import CurveScalar
|
||||
from .curve_point import CurvePoint
|
||||
from .keys import PublicKey, SecretKey, Signature
|
||||
from .keys import PublicKey, SecretKey
|
||||
from .serializable import Serializable, serialize_bool
|
||||
from .signing import Signature
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .key_frag import KeyFragID
|
||||
|
|
|
@ -4,9 +4,10 @@ from typing import Tuple, List, Optional
|
|||
from .curve_point import CurvePoint
|
||||
from .curve_scalar import CurveScalar
|
||||
from .hashing import hash_to_shared_secret, hash_to_cfrag_signature, hash_to_polynomial_arg
|
||||
from .keys import PublicKey, SecretKey, Signature
|
||||
from .keys import PublicKey, SecretKey
|
||||
from .params import PARAMETERS
|
||||
from .serializable import Serializable, serialize_bool, take_bool
|
||||
from .signing import Signature
|
||||
|
||||
|
||||
class KeyFragID(Serializable):
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives.asymmetric import utils
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA
|
||||
|
||||
|
@ -76,49 +75,10 @@ class SecretKey(Serializable):
|
|||
r = CurveScalar.from_int(r_int, check_normalization=False)
|
||||
s = CurveScalar.from_int(s_int, check_normalization=False)
|
||||
|
||||
from .signing import Signature
|
||||
return Signature(r, s)
|
||||
|
||||
|
||||
class Signature(Serializable):
|
||||
"""
|
||||
Wrapper for ECDSA signatures.
|
||||
"""
|
||||
|
||||
def __init__(self, r: CurveScalar, s: CurveScalar):
|
||||
self.r = r
|
||||
self.s = s
|
||||
|
||||
def __repr__(self):
|
||||
return f"ECDSA Signature: {bytes(self).hex()[:15]}"
|
||||
|
||||
def verify_digest(self, verifying_key: 'PublicKey', digest: 'Hash') -> bool:
|
||||
backend_pk = openssl.point_to_pubkey(CURVE, verifying_key.point()._backend_point)
|
||||
signature_algorithm = ECDSA(utils.Prehashed(digest._backend_hash_algorithm))
|
||||
|
||||
message = digest.finalize()
|
||||
signature_der_bytes = utils.encode_dss_signature(int(self.r), int(self.s))
|
||||
|
||||
# TODO: Raise error instead of returning boolean
|
||||
try:
|
||||
backend_pk.verify(signature=signature_der_bytes,
|
||||
data=message,
|
||||
signature_algorithm=signature_algorithm)
|
||||
except InvalidSignature:
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def __take__(cls, data):
|
||||
(r, s), data = cls.__take_types__(data, CurveScalar, CurveScalar)
|
||||
return cls(r, s), data
|
||||
|
||||
def __bytes__(self):
|
||||
return bytes(self.r) + bytes(self.s)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.r == other.r and self.s == other.s
|
||||
|
||||
|
||||
class PublicKey(Serializable):
|
||||
"""
|
||||
Umbral public key.
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives.asymmetric import utils
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA
|
||||
|
||||
from . import openssl
|
||||
from .curve import CURVE
|
||||
from .curve_scalar import CurveScalar
|
||||
from .serializable import Serializable
|
||||
|
||||
|
||||
class Signature(Serializable):
|
||||
"""
|
||||
Wrapper for ECDSA signatures.
|
||||
"""
|
||||
|
||||
def __init__(self, r: CurveScalar, s: CurveScalar):
|
||||
self.r = r
|
||||
self.s = s
|
||||
|
||||
def __repr__(self):
|
||||
return f"ECDSA Signature: {bytes(self).hex()[:15]}"
|
||||
|
||||
def verify_digest(self, verifying_key: 'PublicKey', digest: 'Hash') -> bool:
|
||||
backend_pk = openssl.point_to_pubkey(CURVE, verifying_key.point()._backend_point)
|
||||
signature_algorithm = ECDSA(utils.Prehashed(digest._backend_hash_algorithm))
|
||||
|
||||
message = digest.finalize()
|
||||
signature_der_bytes = utils.encode_dss_signature(int(self.r), int(self.s))
|
||||
|
||||
# TODO: Raise error instead of returning boolean
|
||||
try:
|
||||
backend_pk.verify(signature=signature_der_bytes,
|
||||
data=message,
|
||||
signature_algorithm=signature_algorithm)
|
||||
except InvalidSignature:
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def __take__(cls, data):
|
||||
(r, s), data = cls.__take_types__(data, CurveScalar, CurveScalar)
|
||||
return cls(r, s), data
|
||||
|
||||
def __bytes__(self):
|
||||
return bytes(self.r) + bytes(self.s)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.r == other.r and self.s == other.s
|
Loading…
Reference in New Issue