mirror of https://github.com/nucypher/pyUmbral.git
Add tests
parent
f58a2580dc
commit
c401c52e92
|
@ -0,0 +1,51 @@
|
|||
import pytest
|
||||
|
||||
from umbral import SecretKey, PublicKey, generate_kfrags, encrypt
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def alices_keys():
|
||||
delegating_sk = SecretKey.random()
|
||||
signing_sk = SecretKey.random()
|
||||
return delegating_sk, signing_sk
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bobs_keys():
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
return sk, pk
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kfrags(alices_keys, bobs_keys):
|
||||
delegating_sk, signing_sk = alices_keys
|
||||
receiving_sk, receiving_pk = bobs_keys
|
||||
yield generate_kfrags(delegating_sk=delegating_sk,
|
||||
signing_sk=signing_sk,
|
||||
receiving_pk=receiving_pk,
|
||||
threshold=6, num_kfrags=10)
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def message():
|
||||
message = (b"dnunez [9:30 AM]"
|
||||
b"@Tux we had this super fruitful discussion last night with @jMyles @michwill @KPrasch"
|
||||
b"to sum up: the symmetric ciphertext is now called the 'Chimney'."
|
||||
b"the chimney of the capsule, of course"
|
||||
b"tux [9:32 AM]"
|
||||
b"wat")
|
||||
return message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def capsule_and_ciphertext(alices_keys, message):
|
||||
delegating_sk, _signing_sk = alices_keys
|
||||
capsule, ciphertext = encrypt(PublicKey.from_secret_key(delegating_sk), message)
|
||||
return capsule, ciphertext
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def capsule(capsule_and_ciphertext):
|
||||
capsule, ciphertext = capsule_and_ciphertext
|
||||
return capsule
|
|
@ -0,0 +1,107 @@
|
|||
import pytest
|
||||
|
||||
from umbral import (
|
||||
Capsule,
|
||||
SecretKey,
|
||||
PublicKey,
|
||||
encrypt,
|
||||
decrypt_original,
|
||||
reencrypt,
|
||||
decrypt_reencrypted,
|
||||
generate_kfrags
|
||||
)
|
||||
from umbral.curve_point import CurvePoint
|
||||
|
||||
|
||||
def test_capsule_serialization(alices_keys):
|
||||
|
||||
delegating_sk, _signing_sk = alices_keys
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
capsule, _key = Capsule.from_public_key(delegating_pk)
|
||||
new_capsule = Capsule.from_bytes(bytes(capsule))
|
||||
|
||||
assert capsule == new_capsule
|
||||
|
||||
# Deserializing a bad capsule triggers verification error
|
||||
capsule.point_e = CurvePoint.random()
|
||||
capsule_bytes = bytes(capsule)
|
||||
|
||||
with pytest.raises(Capsule.NotValid):
|
||||
Capsule.from_bytes(capsule_bytes)
|
||||
|
||||
|
||||
def test_capsule_is_hashable(alices_keys):
|
||||
|
||||
delegating_sk, _signing_sk = alices_keys
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
capsule1, key1 = Capsule.from_public_key(delegating_pk)
|
||||
capsule2, key2 = Capsule.from_public_key(delegating_pk)
|
||||
|
||||
assert capsule1 != capsule2
|
||||
assert key1 != key2
|
||||
assert hash(capsule1) != hash(capsule2)
|
||||
|
||||
new_capsule = Capsule.from_bytes(bytes(capsule1))
|
||||
assert hash(new_capsule) == hash(capsule1)
|
||||
|
||||
|
||||
def test_open_original(alices_keys):
|
||||
|
||||
delegating_sk, _signing_sk = alices_keys
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
capsule, key = Capsule.from_public_key(delegating_pk)
|
||||
key_back = capsule.open_original(delegating_sk)
|
||||
assert key == key_back
|
||||
|
||||
|
||||
def test_open_reencrypted(alices_keys, bobs_keys):
|
||||
|
||||
threshold = 6
|
||||
num_kfrags = 10
|
||||
|
||||
delegating_sk, signing_sk = alices_keys
|
||||
receiving_sk, receiving_pk = bobs_keys
|
||||
|
||||
signing_pk = PublicKey.from_secret_key(signing_sk)
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
capsule, key = Capsule.from_public_key(delegating_pk)
|
||||
kfrags = generate_kfrags(delegating_sk=delegating_sk,
|
||||
signing_sk=signing_sk,
|
||||
receiving_pk=receiving_pk,
|
||||
threshold=threshold,
|
||||
num_kfrags=num_kfrags)
|
||||
|
||||
cfrags = [reencrypt(capsule, kfrag) for kfrag in kfrags]
|
||||
key_back = capsule.open_reencrypted(receiving_sk, delegating_pk, cfrags[:threshold])
|
||||
assert key_back == key
|
||||
|
||||
# No cfrags at all
|
||||
with pytest.raises(ValueError, match="Empty CapsuleFrag sequence"):
|
||||
capsule.open_reencrypted(receiving_sk, delegating_pk, [])
|
||||
|
||||
# Not enough cfrags
|
||||
with pytest.raises(ValueError, match="Internal validation failed"):
|
||||
capsule.open_reencrypted(receiving_sk, delegating_pk, cfrags[:threshold-1])
|
||||
|
||||
# Repeating cfrags
|
||||
with pytest.raises(ValueError, match="Some of the CapsuleFrags are repeated"):
|
||||
capsule.open_reencrypted(receiving_sk, delegating_pk, [cfrags[0]] + cfrags[:threshold-1])
|
||||
|
||||
# Mismatched cfrags
|
||||
kfrags2 = generate_kfrags(delegating_sk=delegating_sk,
|
||||
signing_sk=signing_sk,
|
||||
receiving_pk=receiving_pk,
|
||||
threshold=threshold,
|
||||
num_kfrags=num_kfrags)
|
||||
cfrags2 = [reencrypt(capsule, kfrag) for kfrag in kfrags2]
|
||||
with pytest.raises(ValueError, match="CapsuleFrags are not pairwise consistent"):
|
||||
capsule.open_reencrypted(receiving_sk, delegating_pk, [cfrags2[0]] + cfrags[:threshold-1])
|
||||
|
||||
|
||||
def test_capsule_str(capsule):
|
||||
s = str(capsule)
|
||||
assert 'Capsule' in s
|
|
@ -0,0 +1,149 @@
|
|||
from umbral import reencrypt, CapsuleFrag, PublicKey, Capsule
|
||||
from umbral.curve_point import CurvePoint
|
||||
|
||||
|
||||
def test_cfrag_serialization(alices_keys, bobs_keys, capsule, kfrags):
|
||||
|
||||
delegating_sk, signing_sk = alices_keys
|
||||
_receiving_sk, receiving_pk = bobs_keys
|
||||
|
||||
signing_pk = PublicKey.from_secret_key(signing_sk)
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
metadata = b'This is an example of metadata for re-encryption request'
|
||||
for kfrag in kfrags:
|
||||
cfrag = reencrypt(capsule, kfrag, metadata=metadata)
|
||||
cfrag_bytes = bytes(cfrag)
|
||||
|
||||
new_cfrag = CapsuleFrag.from_bytes(cfrag_bytes)
|
||||
assert new_cfrag == cfrag
|
||||
|
||||
assert new_cfrag.verify(capsule,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk,
|
||||
signing_pk=signing_pk,
|
||||
metadata=metadata)
|
||||
|
||||
# No metadata
|
||||
assert not new_cfrag.verify(capsule,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk,
|
||||
signing_pk=signing_pk)
|
||||
|
||||
# Wrong metadata
|
||||
assert not new_cfrag.verify(capsule,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk,
|
||||
signing_pk=signing_pk,
|
||||
metadata=b'Not the same metadata')
|
||||
|
||||
# Wrong delegating key
|
||||
assert not new_cfrag.verify(capsule,
|
||||
delegating_pk=receiving_pk,
|
||||
receiving_pk=receiving_pk,
|
||||
signing_pk=signing_pk,
|
||||
metadata=metadata)
|
||||
|
||||
# Wrong receiving key
|
||||
assert not new_cfrag.verify(capsule,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=delegating_pk,
|
||||
signing_pk=signing_pk,
|
||||
metadata=metadata)
|
||||
|
||||
# Wrong signing key
|
||||
assert not new_cfrag.verify(capsule,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk,
|
||||
signing_pk=receiving_pk,
|
||||
metadata=metadata)
|
||||
|
||||
|
||||
def test_cfrag_serialization_no_metadata(alices_keys, bobs_keys, capsule, kfrags):
|
||||
|
||||
delegating_sk, signing_sk = alices_keys
|
||||
_receiving_sk, receiving_pk = bobs_keys
|
||||
|
||||
signing_pk = PublicKey.from_secret_key(signing_sk)
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
for kfrag in kfrags:
|
||||
|
||||
# Create with no metadata
|
||||
cfrag = reencrypt(capsule, kfrag)
|
||||
cfrag_bytes = bytes(cfrag)
|
||||
new_cfrag = CapsuleFrag.from_bytes(cfrag_bytes)
|
||||
|
||||
assert new_cfrag.verify(capsule,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk,
|
||||
signing_pk=signing_pk)
|
||||
|
||||
assert not new_cfrag.verify(capsule,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk,
|
||||
signing_pk=signing_pk,
|
||||
metadata=b'some metadata')
|
||||
|
||||
|
||||
def test_cfrag_with_wrong_capsule(alices_keys, bobs_keys,
|
||||
kfrags, capsule_and_ciphertext, message):
|
||||
|
||||
capsule, ciphertext = capsule_and_ciphertext
|
||||
|
||||
delegating_sk, signing_sk = alices_keys
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
_receiving_sk, receiving_pk = bobs_keys
|
||||
|
||||
capsule_alice1 = capsule
|
||||
capsule_alice2, _unused_key2 = Capsule.from_public_key(delegating_pk)
|
||||
|
||||
metadata = b"some metadata"
|
||||
cfrag = reencrypt(capsule_alice2, kfrags[0], metadata=metadata)
|
||||
|
||||
assert not cfrag.verify(capsule_alice1,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk,
|
||||
signing_pk=PublicKey.from_secret_key(signing_sk),
|
||||
metadata=metadata)
|
||||
|
||||
|
||||
def test_cfrag_with_wrong_data(kfrags, alices_keys, bobs_keys, capsule_and_ciphertext, message):
|
||||
|
||||
capsule, ciphertext = capsule_and_ciphertext
|
||||
|
||||
delegating_sk, signing_sk = alices_keys
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
_receiving_sk, receiving_pk = bobs_keys
|
||||
|
||||
metadata = b"some metadata"
|
||||
cfrag = reencrypt(capsule, kfrags[0], metadata=metadata)
|
||||
|
||||
# Let's put random garbage in one of the cfrags
|
||||
cfrag.point_e1 = CurvePoint.random()
|
||||
cfrag.point_v1 = CurvePoint.random()
|
||||
|
||||
assert not cfrag.verify(capsule,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk,
|
||||
signing_pk=PublicKey.from_secret_key(signing_sk),
|
||||
metadata=metadata)
|
||||
|
||||
|
||||
def test_cfrag_is_hashable(capsule, kfrags):
|
||||
|
||||
cfrag0 = reencrypt(capsule, kfrags[0], metadata=b'abcdef')
|
||||
cfrag1 = reencrypt(capsule, kfrags[1], metadata=b'abcdef')
|
||||
|
||||
assert hash(cfrag0) != hash(cfrag1)
|
||||
|
||||
new_cfrag = CapsuleFrag.from_bytes(bytes(cfrag0))
|
||||
assert hash(new_cfrag) == hash(cfrag0)
|
||||
|
||||
|
||||
def test_cfrag_str(capsule, kfrags):
|
||||
cfrag0 = reencrypt(capsule, kfrags[0], metadata=b'abcdef')
|
||||
s = str(cfrag0)
|
||||
assert 'CapsuleFrag' in s
|
|
@ -0,0 +1,205 @@
|
|||
import pytest
|
||||
|
||||
try:
|
||||
import umbral_pre as umbral_rs
|
||||
except ImportError:
|
||||
umbral_rs = None
|
||||
|
||||
import umbral as umbral_py
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
if 'implementations' in metafunc.fixturenames:
|
||||
implementations = [(umbral_py, umbral_py)]
|
||||
ids = ['python -> python']
|
||||
if umbral_rs is not None:
|
||||
implementations.extend([(umbral_py, umbral_rs), (umbral_rs, umbral_py)])
|
||||
ids.extend(['python -> rust', 'rust -> python'])
|
||||
|
||||
metafunc.parametrize('implementations', implementations, ids=ids)
|
||||
|
||||
|
||||
def _create_keypair(umbral):
|
||||
sk = umbral.SecretKey.random()
|
||||
pk = umbral.PublicKey.from_secret_key(sk)
|
||||
return bytes(sk), bytes(pk)
|
||||
|
||||
|
||||
def _restore_keys(umbral, sk_bytes, pk_bytes):
|
||||
sk = umbral.SecretKey.from_bytes(sk_bytes)
|
||||
pk_from_sk = umbral.PublicKey.from_secret_key(sk)
|
||||
pk_from_bytes = umbral.PublicKey.from_bytes(pk_bytes)
|
||||
assert pk_from_sk == pk_from_bytes
|
||||
|
||||
|
||||
def test_keys(implementations):
|
||||
umbral1, umbral2 = implementations
|
||||
|
||||
# On client 1
|
||||
sk_bytes, pk_bytes = _create_keypair(umbral1)
|
||||
|
||||
# On client 2
|
||||
_restore_keys(umbral2, sk_bytes, pk_bytes)
|
||||
|
||||
|
||||
def _create_sk_factory_and_sk(umbral, label):
|
||||
skf = umbral.SecretKeyFactory.random()
|
||||
sk = skf.secret_key_by_label(label)
|
||||
return bytes(skf), bytes(sk)
|
||||
|
||||
|
||||
def _check_sk_is_same(umbral, label, skf_bytes, sk_bytes):
|
||||
skf = umbral.SecretKeyFactory.from_bytes(skf_bytes)
|
||||
sk_restored = umbral.SecretKey.from_bytes(sk_bytes)
|
||||
sk_generated = skf.secret_key_by_label(label)
|
||||
assert sk_restored == sk_generated
|
||||
|
||||
|
||||
def test_secret_key_factory(implementations):
|
||||
umbral1, umbral2 = implementations
|
||||
label = b'label'
|
||||
|
||||
skf_bytes, sk_bytes = _create_sk_factory_and_sk(umbral1, label)
|
||||
_check_sk_is_same(umbral2, label, skf_bytes, sk_bytes)
|
||||
|
||||
|
||||
def _encrypt(umbral, plaintext, pk_bytes):
|
||||
pk = umbral.PublicKey.from_bytes(pk_bytes)
|
||||
capsule, ciphertext = umbral.encrypt(pk, plaintext)
|
||||
return bytes(capsule), ciphertext
|
||||
|
||||
|
||||
def _decrypt_original(umbral, sk_bytes, capsule_bytes, ciphertext):
|
||||
capsule = umbral.Capsule.from_bytes(bytes(capsule_bytes))
|
||||
sk = umbral.SecretKey.from_bytes(sk_bytes)
|
||||
return umbral.decrypt_original(sk, capsule, ciphertext)
|
||||
|
||||
|
||||
def test_encrypt_decrypt(implementations):
|
||||
|
||||
umbral1, umbral2 = implementations
|
||||
plaintext = b'peace at dawn'
|
||||
|
||||
# On client 1
|
||||
sk_bytes, pk_bytes = _create_keypair(umbral1)
|
||||
|
||||
# On client 2
|
||||
capsule_bytes, ciphertext = _encrypt(umbral2, plaintext, pk_bytes)
|
||||
|
||||
# On client 1
|
||||
plaintext_decrypted = _decrypt_original(umbral1, sk_bytes, capsule_bytes, ciphertext)
|
||||
|
||||
assert plaintext_decrypted == plaintext
|
||||
|
||||
|
||||
def _generate_kfrags(umbral, delegating_sk_bytes, receiving_pk_bytes,
|
||||
signing_sk_bytes, threshold, num_frags):
|
||||
|
||||
delegating_sk = umbral.SecretKey.from_bytes(delegating_sk_bytes)
|
||||
receiving_pk = umbral.PublicKey.from_bytes(receiving_pk_bytes)
|
||||
signing_sk = umbral.SecretKey.from_bytes(signing_sk_bytes)
|
||||
|
||||
kfrags = umbral.generate_kfrags(delegating_sk,
|
||||
receiving_pk,
|
||||
signing_sk,
|
||||
threshold,
|
||||
num_frags,
|
||||
True,
|
||||
True,
|
||||
)
|
||||
|
||||
return [bytes(kfrag) for kfrag in kfrags]
|
||||
|
||||
|
||||
def _verify_kfrags(umbral, kfrags_bytes, signing_pk_bytes, delegating_pk_bytes, receiving_pk_bytes):
|
||||
kfrags = [umbral.KeyFrag.from_bytes(kfrag_bytes) for kfrag_bytes in kfrags_bytes]
|
||||
signing_pk = umbral.PublicKey.from_bytes(signing_pk_bytes)
|
||||
delegating_pk = umbral.PublicKey.from_bytes(delegating_pk_bytes)
|
||||
receiving_pk = umbral.PublicKey.from_bytes(receiving_pk_bytes)
|
||||
assert all(kfrag.verify(signing_pk, delegating_pk, receiving_pk) for kfrag in kfrags)
|
||||
|
||||
|
||||
def test_kfrags(implementations):
|
||||
|
||||
umbral1, umbral2 = implementations
|
||||
|
||||
threshold = 2
|
||||
num_frags = 3
|
||||
plaintext = b'peace at dawn'
|
||||
|
||||
# On client 1
|
||||
|
||||
receiving_sk_bytes, receiving_pk_bytes = _create_keypair(umbral1)
|
||||
delegating_sk_bytes, delegating_pk_bytes = _create_keypair(umbral1)
|
||||
signing_sk_bytes, signing_pk_bytes = _create_keypair(umbral1)
|
||||
kfrags_bytes = _generate_kfrags(umbral1, delegating_sk_bytes, receiving_pk_bytes,
|
||||
signing_sk_bytes, threshold, num_frags)
|
||||
|
||||
# On client 2
|
||||
|
||||
_verify_kfrags(umbral2, kfrags_bytes, signing_pk_bytes, delegating_pk_bytes, receiving_pk_bytes)
|
||||
|
||||
|
||||
def _reencrypt(umbral, capsule_bytes, kfrags_bytes, threshold, metadata):
|
||||
capsule = umbral.Capsule.from_bytes(bytes(capsule_bytes))
|
||||
kfrags = [umbral.KeyFrag.from_bytes(kfrag_bytes) for kfrag_bytes in kfrags_bytes]
|
||||
cfrags = [umbral.reencrypt(capsule, kfrag, metadata=metadata) for kfrag in kfrags[:threshold]]
|
||||
return [bytes(cfrag) for cfrag in cfrags]
|
||||
|
||||
|
||||
def _decrypt_reencrypted(umbral, receiving_sk_bytes, delegating_pk_bytes, signing_pk_bytes,
|
||||
capsule_bytes, cfrags_bytes, ciphertext, metadata):
|
||||
|
||||
receiving_sk = umbral.SecretKey.from_bytes(receiving_sk_bytes)
|
||||
receiving_pk = umbral.PublicKey.from_secret_key(receiving_sk)
|
||||
delegating_pk = umbral.PublicKey.from_bytes(delegating_pk_bytes)
|
||||
signing_pk = umbral.PublicKey.from_bytes(signing_pk_bytes)
|
||||
|
||||
capsule = umbral.Capsule.from_bytes(bytes(capsule_bytes))
|
||||
cfrags = [umbral.CapsuleFrag.from_bytes(cfrag_bytes) for cfrag_bytes in cfrags_bytes]
|
||||
|
||||
assert all(cfrag.verify(capsule, delegating_pk, receiving_pk, signing_pk, metadata=metadata)
|
||||
for cfrag in cfrags)
|
||||
|
||||
# Decryption by Bob
|
||||
plaintext = umbral.decrypt_reencrypted(receiving_sk,
|
||||
delegating_pk,
|
||||
capsule,
|
||||
cfrags,
|
||||
ciphertext,
|
||||
)
|
||||
|
||||
return plaintext
|
||||
|
||||
|
||||
def test_reencrypt(implementations):
|
||||
|
||||
umbral1, umbral2 = implementations
|
||||
|
||||
metadata = b'metadata'
|
||||
threshold = 2
|
||||
num_frags = 3
|
||||
plaintext = b'peace at dawn'
|
||||
|
||||
# On client 1
|
||||
|
||||
receiving_sk_bytes, receiving_pk_bytes = _create_keypair(umbral1)
|
||||
delegating_sk_bytes, delegating_pk_bytes = _create_keypair(umbral1)
|
||||
signing_sk_bytes, signing_pk_bytes = _create_keypair(umbral1)
|
||||
|
||||
capsule_bytes, ciphertext = _encrypt(umbral1, plaintext, delegating_pk_bytes)
|
||||
|
||||
kfrags_bytes = _generate_kfrags(umbral1, delegating_sk_bytes, receiving_pk_bytes,
|
||||
signing_sk_bytes, threshold, num_frags)
|
||||
|
||||
# On client 2
|
||||
|
||||
cfrags_bytes = _reencrypt(umbral2, capsule_bytes, kfrags_bytes, threshold, metadata)
|
||||
|
||||
# On client 1
|
||||
|
||||
plaintext_reencrypted = _decrypt_reencrypted(umbral1,
|
||||
receiving_sk_bytes, delegating_pk_bytes, signing_pk_bytes,
|
||||
capsule_bytes, cfrags_bytes, ciphertext, metadata)
|
||||
|
||||
assert plaintext_reencrypted == plaintext
|
|
@ -0,0 +1,113 @@
|
|||
import pytest
|
||||
|
||||
from umbral.openssl import Curve, bn_to_int, point_to_affine_coords
|
||||
from umbral.curve import CURVE, CURVES, SECP256R1, SECP256K1, SECP384R1
|
||||
|
||||
|
||||
def test_supported_curves():
|
||||
|
||||
# Ensure we have the correct number of supported curves hardcoded
|
||||
number_of_supported_curves = 3
|
||||
assert len(Curve._supported_curves) == number_of_supported_curves
|
||||
|
||||
# Manually ensure the `_supported curves` dict contains only valid supported curves
|
||||
assert Curve._supported_curves[415] == 'secp256r1'
|
||||
assert Curve._supported_curves[714] == 'secp256k1'
|
||||
assert Curve._supported_curves[715] == 'secp384r1'
|
||||
|
||||
|
||||
def test_create_by_nid():
|
||||
|
||||
nid, name = 714, 'secp256k1'
|
||||
|
||||
# supported
|
||||
_curve_714 = Curve(nid=nid)
|
||||
assert _curve_714.nid == nid
|
||||
assert _curve_714.name == name
|
||||
|
||||
# unsuported
|
||||
with pytest.raises(NotImplementedError):
|
||||
Curve(711)
|
||||
|
||||
|
||||
def test_create_by_name():
|
||||
|
||||
nid, name = 714, 'secp256k1'
|
||||
|
||||
# Supported
|
||||
_curve_secp256k1 = Curve.from_name(name)
|
||||
assert _curve_secp256k1.name == name
|
||||
assert _curve_secp256k1.nid == nid
|
||||
|
||||
# Unsupported
|
||||
with pytest.raises(NotImplementedError):
|
||||
Curve.from_name('abcd123e4')
|
||||
|
||||
|
||||
def test_curve_constants():
|
||||
|
||||
test_p256 = SECP256R1
|
||||
test_secp256k1 = SECP256K1
|
||||
test_p384 = SECP384R1
|
||||
|
||||
assert CURVE == SECP256K1
|
||||
|
||||
# Test the hardcoded curve NIDs are correct:
|
||||
assert test_p256.nid == 415
|
||||
assert test_secp256k1.nid == 714
|
||||
assert test_p384.nid == 715
|
||||
|
||||
# Ensure every curve constant is in the CURVES collection
|
||||
number_of_supported_curves = 3
|
||||
assert len(CURVES) == number_of_supported_curves
|
||||
|
||||
# Ensure all supported curves can be initialized
|
||||
for nid, name in Curve._supported_curves.items():
|
||||
by_nid, by_name = Curve(nid=nid), Curve.from_name(name)
|
||||
assert by_nid.name == name
|
||||
assert by_name.nid == nid
|
||||
|
||||
|
||||
def test_curve_str():
|
||||
for curve in CURVES:
|
||||
s = str(curve)
|
||||
assert str(curve.nid) in s
|
||||
assert str(curve.name) in s
|
||||
|
||||
|
||||
def _curve_info(curve: Curve):
|
||||
assert bn_to_int(curve.bn_order) == curve.order
|
||||
return dict(order=curve.order,
|
||||
field_element_size=curve.field_element_size,
|
||||
scalar_size=curve.scalar_size,
|
||||
generator=point_to_affine_coords(curve, curve.point_generator))
|
||||
|
||||
|
||||
def test_secp256k1():
|
||||
info = _curve_info(SECP256K1)
|
||||
assert info['order'] == 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_BAAEDCE6_AF48A03B_BFD25E8C_D0364141
|
||||
assert info['field_element_size'] == 32
|
||||
assert info['scalar_size'] == 32
|
||||
assert info['generator'] == (
|
||||
0x79BE667E_F9DCBBAC_55A06295_CE870B07_029BFCDB_2DCE28D9_59F2815B_16F81798,
|
||||
0x483ADA77_26A3C465_5DA4FBFC_0E1108A8_FD17B448_A6855419_9C47D08F_FB10D4B8)
|
||||
|
||||
|
||||
def test_p256():
|
||||
info = _curve_info(SECP256R1)
|
||||
assert info['order'] == 0xFFFFFFFF_00000000_FFFFFFFF_FFFFFFFF_BCE6FAAD_A7179E84_F3B9CAC2_FC632551
|
||||
assert info['field_element_size'] == 32
|
||||
assert info['scalar_size'] == 32
|
||||
assert info['generator'] == (
|
||||
0x6B17D1F2_E12C4247_F8BCE6E5_63A440F2_77037D81_2DEB33A0_F4A13945_D898C296,
|
||||
0x4FE342E2_FE1A7F9B_8EE7EB4A_7C0F9E16_2BCE3357_6B315ECE_CBB64068_37BF51F5)
|
||||
|
||||
|
||||
def test_p384():
|
||||
info = _curve_info(SECP384R1)
|
||||
assert info['order'] == 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_C7634D81_F4372DDF_581A0DB2_48B0A77A_ECEC196A_CCC52973
|
||||
assert info['field_element_size'] == 48
|
||||
assert info['scalar_size'] == 48
|
||||
assert info['generator'] == (
|
||||
0xAA87CA22_BE8B0537_8EB1C71E_F320AD74_6E1D3B62_8BA79B98_59F741E0_82542A38_5502F25D_BF55296C_3A545E38_72760AB7,
|
||||
0x3617DE4A_96262C6F_5D9E98BF_9292DC29_F8F41DBD_289A147C_E9DA3113_B5F0B8C0_0A60B1CE_1D7E819D_7A431D7C_90EA0E5F)
|
|
@ -0,0 +1,90 @@
|
|||
import pytest
|
||||
|
||||
from umbral.openssl import ErrorInvalidCompressedPoint, ErrorInvalidPointEncoding
|
||||
from umbral.curve_point import CurvePoint
|
||||
from umbral.curve import CURVE
|
||||
|
||||
|
||||
def test_random():
|
||||
p1 = CurvePoint.random()
|
||||
p2 = CurvePoint.random()
|
||||
assert isinstance(p1, CurvePoint)
|
||||
assert isinstance(p2, CurvePoint)
|
||||
assert p1 != p2
|
||||
|
||||
|
||||
def test_generator_point():
|
||||
"""http://www.secg.org/SEC2-Ver-1.0.pdf Section 2.7.1"""
|
||||
g1 = CurvePoint.generator()
|
||||
|
||||
g_compressed = 0x0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
|
||||
g_compressed_bytes = g_compressed.to_bytes(CURVE.field_element_size + 1, byteorder='big')
|
||||
g2 = CurvePoint.from_bytes(g_compressed_bytes)
|
||||
|
||||
assert g1 == g2
|
||||
|
||||
|
||||
def test_to_and_from_affine():
|
||||
|
||||
x = 17004608369308732328368332205668001941491834793934321461466076545247324070015
|
||||
y = 69725941631324401609944843130171147910924748427773762412028916504484868631573
|
||||
|
||||
p = CurvePoint.from_affine(x, y)
|
||||
|
||||
assert p.to_affine() == (x, y)
|
||||
|
||||
|
||||
def test_invalid_serialized_points():
|
||||
|
||||
field_order = 2**256 - 0x1000003D1
|
||||
|
||||
# A point on secp256k1
|
||||
x = 17004608369308732328368332205668001941491834793934321461466076545247324070015
|
||||
y = 69725941631324401609944843130171147910924748427773762412028916504484868631573
|
||||
|
||||
# Check it
|
||||
assert (y**2 - x**3 - 7) % field_order == 0
|
||||
|
||||
# Should load
|
||||
point_data = b'\x03' + x.to_bytes(CURVE.field_element_size, 'big')
|
||||
p = CurvePoint.from_bytes(point_data)
|
||||
|
||||
# Make it invalid
|
||||
bad_x = x - 1
|
||||
assert (y**2 - bad_x**3 - 7) % field_order != 0
|
||||
|
||||
bad_x_data = b'\x03' + bad_x.to_bytes(CURVE.field_element_size, 'big')
|
||||
with pytest.raises(ErrorInvalidCompressedPoint):
|
||||
CurvePoint.from_bytes(bad_x_data)
|
||||
|
||||
# Valid x, invalid prefix
|
||||
bad_format = b'\xff' + x.to_bytes(CURVE.field_element_size, 'big')
|
||||
with pytest.raises(ErrorInvalidPointEncoding):
|
||||
CurvePoint.from_bytes(bad_format)
|
||||
|
||||
|
||||
def test_serialize_point_at_infinity():
|
||||
|
||||
p = CurvePoint.random()
|
||||
point_at_infinity = p - p
|
||||
|
||||
bytes_point_at_infinity = bytes(point_at_infinity)
|
||||
assert bytes_point_at_infinity == b'\x00'
|
||||
|
||||
|
||||
def test_coords_with_special_characteristics():
|
||||
|
||||
# Testing that a point with x coordinate greater than the curve order is still valid.
|
||||
# In particular, we will test the last valid point from the default curve (secp256k1)
|
||||
# whose x coordinate is `field_order - 3` and is greater than the order of the curve
|
||||
|
||||
field_order = 2**256 - 0x1000003D1
|
||||
compressed = b'\x02' + (field_order-3).to_bytes(32, 'big')
|
||||
|
||||
last_point = CurvePoint.from_bytes(compressed)
|
||||
|
||||
# The same point, but obtained through the from_affine method
|
||||
x = 115792089237316195423570985008687907853269984665640564039457584007908834671660
|
||||
y = 109188863561374057667848968960504138135859662956057034999983532397866404169138
|
||||
|
||||
assert last_point == CurvePoint.from_affine(x, y)
|
|
@ -0,0 +1,127 @@
|
|||
import pytest
|
||||
|
||||
from umbral.curve import CURVE
|
||||
from umbral.curve_scalar import CurveScalar
|
||||
from umbral.hashing import Hash
|
||||
|
||||
|
||||
def test_random():
|
||||
r1 = CurveScalar.random_nonzero()
|
||||
r2 = CurveScalar.random_nonzero()
|
||||
assert r1 != r2
|
||||
assert not r1.is_zero()
|
||||
assert not r2.is_zero()
|
||||
|
||||
|
||||
def test_from_and_to_int():
|
||||
zero = CurveScalar.from_int(0)
|
||||
assert zero.is_zero()
|
||||
assert int(zero) == 0
|
||||
|
||||
one = CurveScalar.one()
|
||||
assert not one.is_zero()
|
||||
assert int(one) == 1
|
||||
|
||||
big_int = CURVE.order - 2
|
||||
big_scalar = CurveScalar.from_int(big_int)
|
||||
assert int(big_scalar) == big_int
|
||||
|
||||
# normalization check
|
||||
with pytest.raises(ValueError):
|
||||
CurveScalar.from_int(CURVE.order)
|
||||
|
||||
# disable normalization check
|
||||
too_big = CurveScalar.from_int(CURVE.order, check_normalization=False)
|
||||
|
||||
|
||||
def test_from_digest():
|
||||
digest = Hash(b'asdf')
|
||||
digest.update(b'some info')
|
||||
s1 = CurveScalar.from_digest(digest)
|
||||
|
||||
digest = Hash(b'asdf')
|
||||
digest.update(b'some info')
|
||||
s2 = CurveScalar.from_digest(digest)
|
||||
|
||||
assert s1 == s2
|
||||
assert int(s1) == int(s2)
|
||||
|
||||
|
||||
def test_eq():
|
||||
random = CurveScalar.random_nonzero()
|
||||
same = CurveScalar.from_int(int(random))
|
||||
different = CurveScalar.random_nonzero()
|
||||
assert random == same
|
||||
assert random == int(same)
|
||||
assert random != different
|
||||
assert random != int(different)
|
||||
|
||||
|
||||
def test_serialization_rotations_of_1():
|
||||
|
||||
size_in_bytes = CURVE.scalar_size
|
||||
for i in range(size_in_bytes):
|
||||
lonely_one = 1 << i
|
||||
bn = CurveScalar.from_int(lonely_one)
|
||||
lonely_one_in_bytes = lonely_one.to_bytes(size_in_bytes, 'big')
|
||||
|
||||
# Check serialization
|
||||
assert bytes(bn) == lonely_one_in_bytes
|
||||
|
||||
# Check deserialization
|
||||
assert CurveScalar.from_bytes(lonely_one_in_bytes) == bn
|
||||
|
||||
|
||||
def test_invalid_deserialization():
|
||||
size_in_bytes = CURVE.scalar_size
|
||||
|
||||
# All-ones bytestring is invalid (since it's greater than the order)
|
||||
lots_of_ones = b'\xFF' * size_in_bytes
|
||||
with pytest.raises(ValueError):
|
||||
CurveScalar.from_bytes(lots_of_ones)
|
||||
|
||||
# Serialization of `order` is invalid since it's not strictly lower than
|
||||
# the order of the curve
|
||||
order = CURVE.order
|
||||
with pytest.raises(ValueError):
|
||||
CurveScalar.from_bytes(order.to_bytes(size_in_bytes, 'big'))
|
||||
|
||||
# On the other hand, serialization of `order - 1` is valid
|
||||
order -= 1
|
||||
CurveScalar.from_bytes(order.to_bytes(size_in_bytes, 'big'))
|
||||
|
||||
|
||||
def test_add():
|
||||
r1 = CurveScalar.random_nonzero()
|
||||
r2 = CurveScalar.random_nonzero()
|
||||
r1i = int(r1)
|
||||
r2i = int(r2)
|
||||
assert r1 + r2 == (r1i + r2i) % CURVE.order
|
||||
assert r1 + r2i == (r1i + r2i) % CURVE.order
|
||||
|
||||
|
||||
def test_sub():
|
||||
r1 = CurveScalar.random_nonzero()
|
||||
r2 = CurveScalar.random_nonzero()
|
||||
r1i = int(r1)
|
||||
r2i = int(r2)
|
||||
assert r1 - r2 == (r1i - r2i) % CURVE.order
|
||||
assert r1 - r2i == (r1i - r2i) % CURVE.order
|
||||
|
||||
|
||||
def test_mul():
|
||||
r1 = CurveScalar.random_nonzero()
|
||||
r2 = CurveScalar.random_nonzero()
|
||||
r1i = int(r1)
|
||||
r2i = int(r2)
|
||||
assert r1 * r2 == (r1i * r2i) % CURVE.order
|
||||
assert r1 * r2i == (r1i * r2i) % CURVE.order
|
||||
|
||||
|
||||
def test_invert():
|
||||
r1 = CurveScalar.random_nonzero()
|
||||
r1i = int(r1)
|
||||
r1inv = r1.invert()
|
||||
assert r1 * r1inv == CurveScalar.one()
|
||||
assert (r1i * int(r1inv)) % CURVE.order == 1
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
import pytest
|
||||
import os
|
||||
|
||||
from umbral.dem import DEM, ErrorInvalidTag
|
||||
|
||||
|
||||
def test_encrypt_decrypt():
|
||||
|
||||
key = os.urandom(DEM.KEY_SIZE)
|
||||
dem = DEM(key)
|
||||
|
||||
plaintext = b'peace at dawn'
|
||||
|
||||
ciphertext0 = dem.encrypt(plaintext)
|
||||
ciphertext1 = dem.encrypt(plaintext)
|
||||
|
||||
assert ciphertext0 != plaintext
|
||||
assert ciphertext1 != plaintext
|
||||
|
||||
# Ciphertext should be different even with same plaintext.
|
||||
assert ciphertext0 != ciphertext1
|
||||
|
||||
# Nonce should be different
|
||||
assert ciphertext0[:DEM.NONCE_SIZE] != ciphertext1[:DEM.NONCE_SIZE]
|
||||
|
||||
cleartext0 = dem.decrypt(ciphertext0)
|
||||
cleartext1 = dem.decrypt(ciphertext1)
|
||||
|
||||
assert cleartext0 == plaintext
|
||||
assert cleartext1 == plaintext
|
||||
|
||||
|
||||
def test_malformed_ciphertext():
|
||||
|
||||
key = os.urandom(DEM.KEY_SIZE)
|
||||
dem = DEM(key)
|
||||
|
||||
plaintext = b'peace at dawn'
|
||||
ciphertext = dem.encrypt(plaintext)
|
||||
|
||||
# So short it we can tell right away it doesn't even contain a nonce
|
||||
with pytest.raises(ValueError, match="The ciphertext must include the nonce"):
|
||||
dem.decrypt(ciphertext[:DEM.NONCE_SIZE-1])
|
||||
|
||||
# Too short to contain a tag
|
||||
with pytest.raises(ValueError, match="The authentication tag is missing or malformed"):
|
||||
dem.decrypt(ciphertext[:DEM.NONCE_SIZE + DEM.TAG_SIZE - 1])
|
||||
|
||||
# Too long
|
||||
with pytest.raises(ErrorInvalidTag):
|
||||
dem.decrypt(ciphertext + b'abcd')
|
||||
|
||||
|
||||
def test_encrypt_decrypt_associated_data():
|
||||
key = os.urandom(32)
|
||||
aad = b'secret code 1234'
|
||||
|
||||
dem = DEM(key)
|
||||
|
||||
plaintext = b'peace at dawn'
|
||||
|
||||
ciphertext0 = dem.encrypt(plaintext, authenticated_data=aad)
|
||||
ciphertext1 = dem.encrypt(plaintext, authenticated_data=aad)
|
||||
|
||||
assert ciphertext0 != plaintext
|
||||
assert ciphertext1 != plaintext
|
||||
|
||||
assert ciphertext0 != ciphertext1
|
||||
|
||||
assert ciphertext0[:DEM.NONCE_SIZE] != ciphertext1[:DEM.NONCE_SIZE]
|
||||
|
||||
cleartext0 = dem.decrypt(ciphertext0, authenticated_data=aad)
|
||||
cleartext1 = dem.decrypt(ciphertext1, authenticated_data=aad)
|
||||
|
||||
assert cleartext0 == plaintext
|
||||
assert cleartext1 == plaintext
|
||||
|
||||
# Attempt decryption with invalid associated data
|
||||
with pytest.raises(ErrorInvalidTag):
|
||||
cleartext2 = dem.decrypt(ciphertext0, authenticated_data=b'wrong data')
|
|
@ -0,0 +1,126 @@
|
|||
import pytest
|
||||
|
||||
from umbral import KeyFrag, PublicKey, generate_kfrags
|
||||
from umbral.key_frag import KeyFragID
|
||||
from umbral.curve_scalar import CurveScalar
|
||||
|
||||
|
||||
def test_kfrag_serialization(alices_keys, bobs_keys, kfrags):
|
||||
|
||||
delegating_sk, signing_sk = alices_keys
|
||||
_receiving_sk, receiving_pk = bobs_keys
|
||||
|
||||
signing_pk = PublicKey.from_secret_key(signing_sk)
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
for kfrag in kfrags:
|
||||
kfrag_bytes = bytes(kfrag)
|
||||
new_kfrag = KeyFrag.from_bytes(kfrag_bytes)
|
||||
|
||||
assert new_kfrag.verify(signing_pk=signing_pk,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk)
|
||||
|
||||
assert new_kfrag == kfrag
|
||||
|
||||
|
||||
def test_kfrag_verification(alices_keys, bobs_keys, kfrags):
|
||||
|
||||
delegating_sk, signing_sk = alices_keys
|
||||
_receiving_sk, receiving_pk = bobs_keys
|
||||
|
||||
signing_pk = PublicKey.from_secret_key(signing_sk)
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
# Wrong signature
|
||||
kfrag = kfrags[0]
|
||||
kfrag.id = KeyFragID.random()
|
||||
kfrag_bytes = bytes(kfrag)
|
||||
new_kfrag = KeyFrag.from_bytes(kfrag_bytes)
|
||||
assert not new_kfrag.verify(signing_pk=signing_pk,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk)
|
||||
|
||||
# Wrong key
|
||||
kfrag = kfrags[1]
|
||||
kfrag.key = CurveScalar.random_nonzero()
|
||||
kfrag_bytes = bytes(kfrag)
|
||||
new_kfrag = KeyFrag.from_bytes(kfrag_bytes)
|
||||
assert not new_kfrag.verify(signing_pk=signing_pk,
|
||||
delegating_pk=delegating_pk,
|
||||
receiving_pk=receiving_pk)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('sign_delegating_key',
|
||||
[False, True],
|
||||
ids=['sign_delegating_key', 'dont_sign_delegating_key'])
|
||||
@pytest.mark.parametrize('sign_receiving_key',
|
||||
[False, True],
|
||||
ids=['sign_receiving_key', 'dont_sign_receiving_key'])
|
||||
def test_kfrag_signing(alices_keys, bobs_keys, sign_delegating_key, sign_receiving_key):
|
||||
|
||||
delegating_sk, signing_sk = alices_keys
|
||||
_receiving_sk, receiving_pk = bobs_keys
|
||||
|
||||
signing_pk = PublicKey.from_secret_key(signing_sk)
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
kfrags = generate_kfrags(delegating_sk=delegating_sk,
|
||||
signing_sk=signing_sk,
|
||||
receiving_pk=receiving_pk,
|
||||
threshold=6,
|
||||
num_kfrags=10,
|
||||
sign_delegating_key=sign_delegating_key,
|
||||
sign_receiving_key=sign_receiving_key)
|
||||
|
||||
kfrag = kfrags[0]
|
||||
|
||||
# serialize/deserialize to make sure sign_* fields are serialized correctly
|
||||
kfrag = KeyFrag.from_bytes(bytes(kfrag))
|
||||
|
||||
for pass_delegating_key, pass_receiving_key in zip([False, True], [False, True]):
|
||||
|
||||
delegating_key_ok = (not sign_delegating_key) or pass_delegating_key
|
||||
receiving_key_ok = (not sign_receiving_key) or pass_receiving_key
|
||||
should_verify = delegating_key_ok and receiving_key_ok
|
||||
|
||||
result = kfrag.verify(signing_pk=signing_pk,
|
||||
delegating_pk=delegating_pk if pass_delegating_key else None,
|
||||
receiving_pk=receiving_pk if pass_receiving_key else None)
|
||||
|
||||
assert result == should_verify
|
||||
|
||||
|
||||
def test_kfrag_is_hashable(kfrags):
|
||||
|
||||
assert hash(kfrags[0]) != hash(kfrags[1])
|
||||
|
||||
new_kfrag = KeyFrag.from_bytes(bytes(kfrags[0]))
|
||||
assert hash(new_kfrag) == hash(kfrags[0])
|
||||
|
||||
|
||||
def test_kfrag_str(kfrags):
|
||||
s = str(kfrags[0])
|
||||
assert "KeyFrag" in s
|
||||
|
||||
|
||||
WRONG_PARAMETERS = (
|
||||
# (num_kfrags, threshold)
|
||||
(-1, -1), (-1, 0), (-1, 5),
|
||||
(0, -1), (0, 0), (0, 5),
|
||||
(1, -1), (1, 0), (1, 5),
|
||||
(5, -1), (5, 0), (5, 10)
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("num_kfrags, threshold", WRONG_PARAMETERS)
|
||||
def test_wrong_threshold_and_num_kfrags(num_kfrags, threshold, alices_keys, bobs_keys):
|
||||
|
||||
delegating_sk, signing_sk = alices_keys
|
||||
_receiving_sk, receiving_pk = bobs_keys
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generate_kfrags(delegating_sk=delegating_sk,
|
||||
signing_sk=signing_sk,
|
||||
receiving_pk=receiving_pk,
|
||||
threshold=threshold,
|
||||
num_kfrags=num_kfrags)
|
|
@ -0,0 +1,203 @@
|
|||
import os
|
||||
import string
|
||||
|
||||
import pytest
|
||||
|
||||
from umbral.keys import PublicKey, SecretKey, SecretKeyFactory, Signature
|
||||
from umbral.hashing import Hash
|
||||
|
||||
|
||||
def test_gen_key():
|
||||
sk = SecretKey.random()
|
||||
assert type(sk) == SecretKey
|
||||
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
assert type(pk) == PublicKey
|
||||
|
||||
pk2 = PublicKey.from_secret_key(sk)
|
||||
assert pk == pk2
|
||||
|
||||
|
||||
def test_derive_key_from_label():
|
||||
factory = SecretKeyFactory.random()
|
||||
|
||||
label = b"my_healthcare_information"
|
||||
|
||||
sk1 = factory.secret_key_by_label(label)
|
||||
assert type(sk1) == SecretKey
|
||||
|
||||
pk1 = PublicKey.from_secret_key(sk1)
|
||||
assert type(pk1) == PublicKey
|
||||
|
||||
# Check that key derivation is reproducible
|
||||
sk2 = factory.secret_key_by_label(label)
|
||||
pk2 = PublicKey.from_secret_key(sk2)
|
||||
assert sk1 == sk2
|
||||
assert pk1 == pk2
|
||||
|
||||
# Different labels on the same master secret create different keys
|
||||
label = b"my_tax_information"
|
||||
sk3 = factory.secret_key_by_label(label)
|
||||
pk3 = PublicKey.from_secret_key(sk3)
|
||||
assert sk1 != sk3
|
||||
|
||||
|
||||
def test_secret_key_serialization():
|
||||
sk = SecretKey.random()
|
||||
encoded_key = bytes(sk)
|
||||
decoded_key = SecretKey.from_bytes(encoded_key)
|
||||
assert sk == decoded_key
|
||||
|
||||
|
||||
def test_secret_key_str():
|
||||
sk = SecretKey.random()
|
||||
s = str(sk)
|
||||
assert s == "SecretKey:..."
|
||||
|
||||
|
||||
def test_secret_key_hash():
|
||||
sk = SecretKey.random()
|
||||
# Insecure Python hash, shouldn't be available.
|
||||
with pytest.raises(NotImplementedError):
|
||||
hash(sk)
|
||||
|
||||
|
||||
def test_secret_key_factory_str():
|
||||
skf = SecretKeyFactory.random()
|
||||
s = str(skf)
|
||||
assert s == "SecretKeyFactory:..."
|
||||
|
||||
|
||||
def test_secret_key_factory_hash():
|
||||
skf = SecretKeyFactory.random()
|
||||
# Insecure Python hash, shouldn't be available.
|
||||
with pytest.raises(NotImplementedError):
|
||||
hash(skf)
|
||||
|
||||
|
||||
def test_public_key_serialization():
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
encoded_key = bytes(pk)
|
||||
decoded_key = PublicKey.from_bytes(encoded_key)
|
||||
assert pk == decoded_key
|
||||
|
||||
|
||||
def test_public_key_point():
|
||||
pk = PublicKey.from_secret_key(SecretKey.random())
|
||||
assert bytes(pk) == bytes(pk.point())
|
||||
|
||||
|
||||
def test_public_key_str():
|
||||
pk = PublicKey.from_secret_key(SecretKey.random())
|
||||
s = str(pk)
|
||||
assert 'PublicKey' in s
|
||||
|
||||
|
||||
def test_keying_material_serialization():
|
||||
factory = SecretKeyFactory.random()
|
||||
|
||||
encoded_factory = bytes(factory)
|
||||
decoded_factory = SecretKeyFactory.from_bytes(encoded_factory)
|
||||
|
||||
label = os.urandom(32)
|
||||
sk1 = factory.secret_key_by_label(label)
|
||||
sk2 = decoded_factory.secret_key_by_label(label)
|
||||
assert sk1 == sk2
|
||||
|
||||
|
||||
def test_public_key_is_hashable():
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
sk2 = SecretKey.random()
|
||||
pk2 = PublicKey.from_secret_key(sk2)
|
||||
assert hash(pk) != hash(pk2)
|
||||
|
||||
pk3 = PublicKey.from_bytes(bytes(pk))
|
||||
assert hash(pk) == hash(pk3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('execution_number', range(20)) # Run this test 20 times.
|
||||
def test_sign_and_verify(execution_number):
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
assert signature.verify_digest(pk, digest)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('execution_number', range(20)) # Run this test 20 times.
|
||||
def test_sign_serialize_and_verify(execution_number):
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
signature_bytes = bytes(signature)
|
||||
signature_restored = Signature.from_bytes(signature_bytes)
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
assert signature_restored.verify_digest(pk, digest)
|
||||
|
||||
|
||||
def test_verification_fail():
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
# wrong DST
|
||||
digest = Hash(b"other dst")
|
||||
digest.update(message)
|
||||
assert not signature.verify_digest(pk, digest)
|
||||
|
||||
# wrong message
|
||||
digest = Hash(dst)
|
||||
digest.update(b"no peace at dawn")
|
||||
assert not signature.verify_digest(pk, digest)
|
||||
|
||||
# bad signature
|
||||
signature_bytes = bytes(signature)
|
||||
signature_bytes = b'\x00' + signature_bytes[1:]
|
||||
signature_restored = Signature.from_bytes(signature_bytes)
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
assert not signature_restored.verify_digest(pk, digest)
|
||||
|
||||
|
||||
def test_signature_repr():
|
||||
|
||||
sk = SecretKey.random()
|
||||
pk = PublicKey.from_secret_key(sk)
|
||||
|
||||
message = b"peace at dawn"
|
||||
dst = b"dst"
|
||||
|
||||
digest = Hash(dst)
|
||||
digest.update(message)
|
||||
signature = sk.sign_digest(digest)
|
||||
|
||||
s = repr(signature)
|
||||
assert 'Signature' in s
|
|
@ -0,0 +1,95 @@
|
|||
import pytest
|
||||
|
||||
from umbral import (
|
||||
SecretKey,
|
||||
PublicKey,
|
||||
encrypt,
|
||||
generate_kfrags,
|
||||
decrypt_original,
|
||||
reencrypt,
|
||||
decrypt_reencrypted,
|
||||
)
|
||||
from umbral.dem import ErrorInvalidTag
|
||||
|
||||
|
||||
def test_public_key_encryption(alices_keys):
|
||||
delegating_sk, _ = alices_keys
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
plaintext = b'peace at dawn'
|
||||
capsule, ciphertext = encrypt(delegating_pk, plaintext)
|
||||
plaintext_decrypted = decrypt_original(delegating_sk, capsule, ciphertext)
|
||||
assert plaintext == plaintext_decrypted
|
||||
|
||||
# Wrong secret key
|
||||
sk = SecretKey.random()
|
||||
with pytest.raises(ErrorInvalidTag):
|
||||
decrypt_original(sk, capsule, ciphertext)
|
||||
|
||||
|
||||
SIMPLE_API_PARAMETERS = (
|
||||
# (num_kfrags, threshold)
|
||||
(1, 1),
|
||||
(6, 1),
|
||||
(6, 4),
|
||||
(6, 6),
|
||||
(50, 30)
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("num_kfrags, threshold", SIMPLE_API_PARAMETERS)
|
||||
def test_simple_api(num_kfrags, threshold):
|
||||
"""
|
||||
This test models the main interactions between actors (i.e., Alice,
|
||||
Bob, Data Source, and Ursulas) and artifacts (i.e., public and private keys,
|
||||
ciphertexts, capsules, KFrags, CFrags, etc).
|
||||
|
||||
The test covers all the main stages of data sharing:
|
||||
key generation, delegation, encryption, decryption by
|
||||
Alice, re-encryption by Ursula, and decryption by Bob.
|
||||
"""
|
||||
|
||||
# Key Generation (Alice)
|
||||
delegating_sk = SecretKey.random()
|
||||
delegating_pk = PublicKey.from_secret_key(delegating_sk)
|
||||
|
||||
signing_sk = SecretKey.random()
|
||||
signing_pk = PublicKey.from_secret_key(signing_sk)
|
||||
|
||||
# Key Generation (Bob)
|
||||
receiving_sk = SecretKey.random()
|
||||
receiving_pk = PublicKey.from_secret_key(receiving_sk)
|
||||
|
||||
# Encryption by an unnamed data source
|
||||
plaintext = b'peace at dawn'
|
||||
capsule, ciphertext = encrypt(delegating_pk, plaintext)
|
||||
|
||||
# Decryption by Alice
|
||||
plaintext_decrypted = decrypt_original(delegating_sk, capsule, ciphertext)
|
||||
assert plaintext_decrypted == plaintext
|
||||
|
||||
# Split Re-Encryption Key Generation (aka Delegation)
|
||||
kfrags = generate_kfrags(delegating_sk, receiving_pk, signing_sk, threshold, num_kfrags)
|
||||
|
||||
# Bob requests re-encryption to some set of M ursulas
|
||||
cfrags = list()
|
||||
for kfrag in kfrags[:threshold]:
|
||||
# Ursula checks that the received kfrag is valid
|
||||
assert kfrag.verify(signing_pk, delegating_pk, receiving_pk)
|
||||
|
||||
# Re-encryption by an Ursula
|
||||
cfrag = reencrypt(capsule, kfrag)
|
||||
|
||||
# Bob collects the result
|
||||
cfrags.append(cfrag)
|
||||
|
||||
# Bob checks that the received cfrags are valid
|
||||
assert all(cfrag.verify(capsule, delegating_pk, receiving_pk, signing_pk) for cfrag in cfrags)
|
||||
|
||||
# Decryption by Bob
|
||||
plaintext_reenc = decrypt_reencrypted(receiving_sk,
|
||||
delegating_pk,
|
||||
capsule,
|
||||
cfrags[:threshold],
|
||||
ciphertext,
|
||||
)
|
||||
|
||||
assert plaintext_reenc == plaintext
|
|
@ -0,0 +1,92 @@
|
|||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from umbral.serializable import Serializable, serialize_bool, take_bool
|
||||
|
||||
|
||||
class A(Serializable):
|
||||
|
||||
def __init__(self, val: int):
|
||||
assert 0 <= val < 2**32
|
||||
self.val = val
|
||||
|
||||
@classmethod
|
||||
def __take__(cls, data):
|
||||
val_bytes, data = cls.__take_bytes__(data, 4)
|
||||
return cls(int.from_bytes(val_bytes, byteorder='big')), data
|
||||
|
||||
def __bytes__(self):
|
||||
return self.val.to_bytes(4, byteorder='big')
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, A) and self.val == other.val
|
||||
|
||||
|
||||
class B(Serializable):
|
||||
|
||||
def __init__(self, val: int):
|
||||
assert 0 <= val < 2**16
|
||||
self.val = val
|
||||
|
||||
@classmethod
|
||||
def __take__(cls, data):
|
||||
val_bytes, data = cls.__take_bytes__(data, 2)
|
||||
return cls(int.from_bytes(val_bytes, byteorder='big')), data
|
||||
|
||||
def __bytes__(self):
|
||||
return self.val.to_bytes(2, byteorder='big')
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, B) and self.val == other.val
|
||||
|
||||
|
||||
class C(Serializable):
|
||||
|
||||
def __init__(self, a: A, b: B):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
@classmethod
|
||||
def __take__(cls, data):
|
||||
components, data = cls.__take_types__(data, A, B)
|
||||
return cls(*components), data
|
||||
|
||||
def __bytes__(self):
|
||||
return bytes(self.a) + bytes(self.b)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, C) and self.a == other.a and self.b == other.b
|
||||
|
||||
|
||||
def test_normal_operation():
|
||||
a = A(2**32 - 123)
|
||||
b = B(2**16 - 456)
|
||||
c = C(a, b)
|
||||
c_back = C.from_bytes(bytes(c))
|
||||
assert c_back == c
|
||||
|
||||
|
||||
def test_too_many_bytes():
|
||||
a = A(2**32 - 123)
|
||||
b = B(2**16 - 456)
|
||||
c = C(a, b)
|
||||
with pytest.raises(ValueError, match="1 bytes remaining after deserializing"):
|
||||
C.from_bytes(bytes(c) + b'\x00')
|
||||
|
||||
|
||||
def test_not_enough_bytes():
|
||||
a = A(2**32 - 123)
|
||||
b = B(2**16 - 456)
|
||||
c = C(a, b)
|
||||
# Will happen on deserialization of B - 1 byte missing
|
||||
with pytest.raises(ValueError, match="cannot take 2 bytes from a bytestring of size 1"):
|
||||
C.from_bytes(bytes(c)[:-1])
|
||||
|
||||
|
||||
def test_serialize_bool():
|
||||
assert take_bool(serialize_bool(True) + b'1234') == (True, b'1234')
|
||||
assert take_bool(serialize_bool(False) + b'12') == (False, b'12')
|
||||
error_msg = re.escape("Incorrectly serialized boolean; expected b'\\x00' or b'\\x01', got b'z'")
|
||||
with pytest.raises(ValueError, match=error_msg):
|
||||
take_bool(b'z1234')
|
|
@ -4,11 +4,13 @@ from typing import Optional
|
|||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
||||
import nacl
|
||||
from nacl.bindings.crypto_aead import (
|
||||
crypto_aead_xchacha20poly1305_ietf_encrypt as xchacha_encrypt,
|
||||
crypto_aead_xchacha20poly1305_ietf_decrypt as xchacha_decrypt,
|
||||
crypto_aead_xchacha20poly1305_ietf_KEYBYTES as XCHACHA_KEY_SIZE,
|
||||
crypto_aead_xchacha20poly1305_ietf_NPUBBYTES as XCHACHA_NONCE_SIZE,
|
||||
crypto_aead_xchacha20poly1305_ietf_ABYTES as XCHACHA_TAG_SIZE,
|
||||
)
|
||||
|
||||
from . import openssl
|
||||
|
@ -28,10 +30,15 @@ def kdf(data: bytes,
|
|||
return hkdf.derive(data)
|
||||
|
||||
|
||||
class ErrorInvalidTag(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DEM:
|
||||
|
||||
KEY_SIZE = XCHACHA_KEY_SIZE
|
||||
NONCE_SIZE = XCHACHA_NONCE_SIZE
|
||||
TAG_SIZE = XCHACHA_TAG_SIZE
|
||||
|
||||
def __init__(self,
|
||||
key_material: bytes,
|
||||
|
@ -53,5 +60,11 @@ class DEM:
|
|||
nonce = nonce_and_ciphertext[:self.NONCE_SIZE]
|
||||
ciphertext = nonce_and_ciphertext[self.NONCE_SIZE:]
|
||||
|
||||
# TODO: replace `nacl.exceptions.CryptoError` with our error?
|
||||
return xchacha_decrypt(ciphertext, authenticated_data, nonce, self._key)
|
||||
# Prevent an out of bounds error deep in NaCl
|
||||
if len(ciphertext) < self.TAG_SIZE:
|
||||
raise ValueError(f"The authentication tag is missing or malformed")
|
||||
|
||||
try:
|
||||
return xchacha_decrypt(ciphertext, authenticated_data, nonce, self._key)
|
||||
except nacl.exceptions.CryptoError:
|
||||
raise ErrorInvalidTag
|
||||
|
|
Loading…
Reference in New Issue