Add tests

pull/263/head
Bogdan Opanchuk 2021-03-19 20:21:01 -07:00
parent f58a2580dc
commit c401c52e92
13 changed files with 1453 additions and 2 deletions

51
tests/conftest.py Normal file
View File

@ -0,0 +1,51 @@
import pytest
from umbral import SecretKey, PublicKey, generate_kfrags, encrypt
@pytest.fixture
def alices_keys():
delegating_sk = SecretKey.random()
signing_sk = SecretKey.random()
return delegating_sk, signing_sk
@pytest.fixture
def bobs_keys():
sk = SecretKey.random()
pk = PublicKey.from_secret_key(sk)
return sk, pk
@pytest.fixture
def kfrags(alices_keys, bobs_keys):
delegating_sk, signing_sk = alices_keys
receiving_sk, receiving_pk = bobs_keys
yield generate_kfrags(delegating_sk=delegating_sk,
signing_sk=signing_sk,
receiving_pk=receiving_pk,
threshold=6, num_kfrags=10)
@pytest.fixture(scope='session')
def message():
message = (b"dnunez [9:30 AM]"
b"@Tux we had this super fruitful discussion last night with @jMyles @michwill @KPrasch"
b"to sum up: the symmetric ciphertext is now called the 'Chimney'."
b"the chimney of the capsule, of course"
b"tux [9:32 AM]"
b"wat")
return message
@pytest.fixture
def capsule_and_ciphertext(alices_keys, message):
delegating_sk, _signing_sk = alices_keys
capsule, ciphertext = encrypt(PublicKey.from_secret_key(delegating_sk), message)
return capsule, ciphertext
@pytest.fixture
def capsule(capsule_and_ciphertext):
capsule, ciphertext = capsule_and_ciphertext
return capsule

107
tests/test_capsule.py Normal file
View File

@ -0,0 +1,107 @@
import pytest
from umbral import (
Capsule,
SecretKey,
PublicKey,
encrypt,
decrypt_original,
reencrypt,
decrypt_reencrypted,
generate_kfrags
)
from umbral.curve_point import CurvePoint
def test_capsule_serialization(alices_keys):
delegating_sk, _signing_sk = alices_keys
delegating_pk = PublicKey.from_secret_key(delegating_sk)
capsule, _key = Capsule.from_public_key(delegating_pk)
new_capsule = Capsule.from_bytes(bytes(capsule))
assert capsule == new_capsule
# Deserializing a bad capsule triggers verification error
capsule.point_e = CurvePoint.random()
capsule_bytes = bytes(capsule)
with pytest.raises(Capsule.NotValid):
Capsule.from_bytes(capsule_bytes)
def test_capsule_is_hashable(alices_keys):
delegating_sk, _signing_sk = alices_keys
delegating_pk = PublicKey.from_secret_key(delegating_sk)
capsule1, key1 = Capsule.from_public_key(delegating_pk)
capsule2, key2 = Capsule.from_public_key(delegating_pk)
assert capsule1 != capsule2
assert key1 != key2
assert hash(capsule1) != hash(capsule2)
new_capsule = Capsule.from_bytes(bytes(capsule1))
assert hash(new_capsule) == hash(capsule1)
def test_open_original(alices_keys):
delegating_sk, _signing_sk = alices_keys
delegating_pk = PublicKey.from_secret_key(delegating_sk)
capsule, key = Capsule.from_public_key(delegating_pk)
key_back = capsule.open_original(delegating_sk)
assert key == key_back
def test_open_reencrypted(alices_keys, bobs_keys):
threshold = 6
num_kfrags = 10
delegating_sk, signing_sk = alices_keys
receiving_sk, receiving_pk = bobs_keys
signing_pk = PublicKey.from_secret_key(signing_sk)
delegating_pk = PublicKey.from_secret_key(delegating_sk)
capsule, key = Capsule.from_public_key(delegating_pk)
kfrags = generate_kfrags(delegating_sk=delegating_sk,
signing_sk=signing_sk,
receiving_pk=receiving_pk,
threshold=threshold,
num_kfrags=num_kfrags)
cfrags = [reencrypt(capsule, kfrag) for kfrag in kfrags]
key_back = capsule.open_reencrypted(receiving_sk, delegating_pk, cfrags[:threshold])
assert key_back == key
# No cfrags at all
with pytest.raises(ValueError, match="Empty CapsuleFrag sequence"):
capsule.open_reencrypted(receiving_sk, delegating_pk, [])
# Not enough cfrags
with pytest.raises(ValueError, match="Internal validation failed"):
capsule.open_reencrypted(receiving_sk, delegating_pk, cfrags[:threshold-1])
# Repeating cfrags
with pytest.raises(ValueError, match="Some of the CapsuleFrags are repeated"):
capsule.open_reencrypted(receiving_sk, delegating_pk, [cfrags[0]] + cfrags[:threshold-1])
# Mismatched cfrags
kfrags2 = generate_kfrags(delegating_sk=delegating_sk,
signing_sk=signing_sk,
receiving_pk=receiving_pk,
threshold=threshold,
num_kfrags=num_kfrags)
cfrags2 = [reencrypt(capsule, kfrag) for kfrag in kfrags2]
with pytest.raises(ValueError, match="CapsuleFrags are not pairwise consistent"):
capsule.open_reencrypted(receiving_sk, delegating_pk, [cfrags2[0]] + cfrags[:threshold-1])
def test_capsule_str(capsule):
s = str(capsule)
assert 'Capsule' in s

149
tests/test_capsule_frag.py Normal file
View File

@ -0,0 +1,149 @@
from umbral import reencrypt, CapsuleFrag, PublicKey, Capsule
from umbral.curve_point import CurvePoint
def test_cfrag_serialization(alices_keys, bobs_keys, capsule, kfrags):
delegating_sk, signing_sk = alices_keys
_receiving_sk, receiving_pk = bobs_keys
signing_pk = PublicKey.from_secret_key(signing_sk)
delegating_pk = PublicKey.from_secret_key(delegating_sk)
metadata = b'This is an example of metadata for re-encryption request'
for kfrag in kfrags:
cfrag = reencrypt(capsule, kfrag, metadata=metadata)
cfrag_bytes = bytes(cfrag)
new_cfrag = CapsuleFrag.from_bytes(cfrag_bytes)
assert new_cfrag == cfrag
assert new_cfrag.verify(capsule,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk,
signing_pk=signing_pk,
metadata=metadata)
# No metadata
assert not new_cfrag.verify(capsule,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk,
signing_pk=signing_pk)
# Wrong metadata
assert not new_cfrag.verify(capsule,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk,
signing_pk=signing_pk,
metadata=b'Not the same metadata')
# Wrong delegating key
assert not new_cfrag.verify(capsule,
delegating_pk=receiving_pk,
receiving_pk=receiving_pk,
signing_pk=signing_pk,
metadata=metadata)
# Wrong receiving key
assert not new_cfrag.verify(capsule,
delegating_pk=delegating_pk,
receiving_pk=delegating_pk,
signing_pk=signing_pk,
metadata=metadata)
# Wrong signing key
assert not new_cfrag.verify(capsule,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk,
signing_pk=receiving_pk,
metadata=metadata)
def test_cfrag_serialization_no_metadata(alices_keys, bobs_keys, capsule, kfrags):
delegating_sk, signing_sk = alices_keys
_receiving_sk, receiving_pk = bobs_keys
signing_pk = PublicKey.from_secret_key(signing_sk)
delegating_pk = PublicKey.from_secret_key(delegating_sk)
for kfrag in kfrags:
# Create with no metadata
cfrag = reencrypt(capsule, kfrag)
cfrag_bytes = bytes(cfrag)
new_cfrag = CapsuleFrag.from_bytes(cfrag_bytes)
assert new_cfrag.verify(capsule,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk,
signing_pk=signing_pk)
assert not new_cfrag.verify(capsule,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk,
signing_pk=signing_pk,
metadata=b'some metadata')
def test_cfrag_with_wrong_capsule(alices_keys, bobs_keys,
kfrags, capsule_and_ciphertext, message):
capsule, ciphertext = capsule_and_ciphertext
delegating_sk, signing_sk = alices_keys
delegating_pk = PublicKey.from_secret_key(delegating_sk)
_receiving_sk, receiving_pk = bobs_keys
capsule_alice1 = capsule
capsule_alice2, _unused_key2 = Capsule.from_public_key(delegating_pk)
metadata = b"some metadata"
cfrag = reencrypt(capsule_alice2, kfrags[0], metadata=metadata)
assert not cfrag.verify(capsule_alice1,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk,
signing_pk=PublicKey.from_secret_key(signing_sk),
metadata=metadata)
def test_cfrag_with_wrong_data(kfrags, alices_keys, bobs_keys, capsule_and_ciphertext, message):
capsule, ciphertext = capsule_and_ciphertext
delegating_sk, signing_sk = alices_keys
delegating_pk = PublicKey.from_secret_key(delegating_sk)
_receiving_sk, receiving_pk = bobs_keys
metadata = b"some metadata"
cfrag = reencrypt(capsule, kfrags[0], metadata=metadata)
# Let's put random garbage in one of the cfrags
cfrag.point_e1 = CurvePoint.random()
cfrag.point_v1 = CurvePoint.random()
assert not cfrag.verify(capsule,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk,
signing_pk=PublicKey.from_secret_key(signing_sk),
metadata=metadata)
def test_cfrag_is_hashable(capsule, kfrags):
cfrag0 = reencrypt(capsule, kfrags[0], metadata=b'abcdef')
cfrag1 = reencrypt(capsule, kfrags[1], metadata=b'abcdef')
assert hash(cfrag0) != hash(cfrag1)
new_cfrag = CapsuleFrag.from_bytes(bytes(cfrag0))
assert hash(new_cfrag) == hash(cfrag0)
def test_cfrag_str(capsule, kfrags):
cfrag0 = reencrypt(capsule, kfrags[0], metadata=b'abcdef')
s = str(cfrag0)
assert 'CapsuleFrag' in s

205
tests/test_compatibility.py Normal file
View File

@ -0,0 +1,205 @@
import pytest
try:
import umbral_pre as umbral_rs
except ImportError:
umbral_rs = None
import umbral as umbral_py
def pytest_generate_tests(metafunc):
if 'implementations' in metafunc.fixturenames:
implementations = [(umbral_py, umbral_py)]
ids = ['python -> python']
if umbral_rs is not None:
implementations.extend([(umbral_py, umbral_rs), (umbral_rs, umbral_py)])
ids.extend(['python -> rust', 'rust -> python'])
metafunc.parametrize('implementations', implementations, ids=ids)
def _create_keypair(umbral):
sk = umbral.SecretKey.random()
pk = umbral.PublicKey.from_secret_key(sk)
return bytes(sk), bytes(pk)
def _restore_keys(umbral, sk_bytes, pk_bytes):
sk = umbral.SecretKey.from_bytes(sk_bytes)
pk_from_sk = umbral.PublicKey.from_secret_key(sk)
pk_from_bytes = umbral.PublicKey.from_bytes(pk_bytes)
assert pk_from_sk == pk_from_bytes
def test_keys(implementations):
umbral1, umbral2 = implementations
# On client 1
sk_bytes, pk_bytes = _create_keypair(umbral1)
# On client 2
_restore_keys(umbral2, sk_bytes, pk_bytes)
def _create_sk_factory_and_sk(umbral, label):
skf = umbral.SecretKeyFactory.random()
sk = skf.secret_key_by_label(label)
return bytes(skf), bytes(sk)
def _check_sk_is_same(umbral, label, skf_bytes, sk_bytes):
skf = umbral.SecretKeyFactory.from_bytes(skf_bytes)
sk_restored = umbral.SecretKey.from_bytes(sk_bytes)
sk_generated = skf.secret_key_by_label(label)
assert sk_restored == sk_generated
def test_secret_key_factory(implementations):
umbral1, umbral2 = implementations
label = b'label'
skf_bytes, sk_bytes = _create_sk_factory_and_sk(umbral1, label)
_check_sk_is_same(umbral2, label, skf_bytes, sk_bytes)
def _encrypt(umbral, plaintext, pk_bytes):
pk = umbral.PublicKey.from_bytes(pk_bytes)
capsule, ciphertext = umbral.encrypt(pk, plaintext)
return bytes(capsule), ciphertext
def _decrypt_original(umbral, sk_bytes, capsule_bytes, ciphertext):
capsule = umbral.Capsule.from_bytes(bytes(capsule_bytes))
sk = umbral.SecretKey.from_bytes(sk_bytes)
return umbral.decrypt_original(sk, capsule, ciphertext)
def test_encrypt_decrypt(implementations):
umbral1, umbral2 = implementations
plaintext = b'peace at dawn'
# On client 1
sk_bytes, pk_bytes = _create_keypair(umbral1)
# On client 2
capsule_bytes, ciphertext = _encrypt(umbral2, plaintext, pk_bytes)
# On client 1
plaintext_decrypted = _decrypt_original(umbral1, sk_bytes, capsule_bytes, ciphertext)
assert plaintext_decrypted == plaintext
def _generate_kfrags(umbral, delegating_sk_bytes, receiving_pk_bytes,
signing_sk_bytes, threshold, num_frags):
delegating_sk = umbral.SecretKey.from_bytes(delegating_sk_bytes)
receiving_pk = umbral.PublicKey.from_bytes(receiving_pk_bytes)
signing_sk = umbral.SecretKey.from_bytes(signing_sk_bytes)
kfrags = umbral.generate_kfrags(delegating_sk,
receiving_pk,
signing_sk,
threshold,
num_frags,
True,
True,
)
return [bytes(kfrag) for kfrag in kfrags]
def _verify_kfrags(umbral, kfrags_bytes, signing_pk_bytes, delegating_pk_bytes, receiving_pk_bytes):
kfrags = [umbral.KeyFrag.from_bytes(kfrag_bytes) for kfrag_bytes in kfrags_bytes]
signing_pk = umbral.PublicKey.from_bytes(signing_pk_bytes)
delegating_pk = umbral.PublicKey.from_bytes(delegating_pk_bytes)
receiving_pk = umbral.PublicKey.from_bytes(receiving_pk_bytes)
assert all(kfrag.verify(signing_pk, delegating_pk, receiving_pk) for kfrag in kfrags)
def test_kfrags(implementations):
umbral1, umbral2 = implementations
threshold = 2
num_frags = 3
plaintext = b'peace at dawn'
# On client 1
receiving_sk_bytes, receiving_pk_bytes = _create_keypair(umbral1)
delegating_sk_bytes, delegating_pk_bytes = _create_keypair(umbral1)
signing_sk_bytes, signing_pk_bytes = _create_keypair(umbral1)
kfrags_bytes = _generate_kfrags(umbral1, delegating_sk_bytes, receiving_pk_bytes,
signing_sk_bytes, threshold, num_frags)
# On client 2
_verify_kfrags(umbral2, kfrags_bytes, signing_pk_bytes, delegating_pk_bytes, receiving_pk_bytes)
def _reencrypt(umbral, capsule_bytes, kfrags_bytes, threshold, metadata):
capsule = umbral.Capsule.from_bytes(bytes(capsule_bytes))
kfrags = [umbral.KeyFrag.from_bytes(kfrag_bytes) for kfrag_bytes in kfrags_bytes]
cfrags = [umbral.reencrypt(capsule, kfrag, metadata=metadata) for kfrag in kfrags[:threshold]]
return [bytes(cfrag) for cfrag in cfrags]
def _decrypt_reencrypted(umbral, receiving_sk_bytes, delegating_pk_bytes, signing_pk_bytes,
capsule_bytes, cfrags_bytes, ciphertext, metadata):
receiving_sk = umbral.SecretKey.from_bytes(receiving_sk_bytes)
receiving_pk = umbral.PublicKey.from_secret_key(receiving_sk)
delegating_pk = umbral.PublicKey.from_bytes(delegating_pk_bytes)
signing_pk = umbral.PublicKey.from_bytes(signing_pk_bytes)
capsule = umbral.Capsule.from_bytes(bytes(capsule_bytes))
cfrags = [umbral.CapsuleFrag.from_bytes(cfrag_bytes) for cfrag_bytes in cfrags_bytes]
assert all(cfrag.verify(capsule, delegating_pk, receiving_pk, signing_pk, metadata=metadata)
for cfrag in cfrags)
# Decryption by Bob
plaintext = umbral.decrypt_reencrypted(receiving_sk,
delegating_pk,
capsule,
cfrags,
ciphertext,
)
return plaintext
def test_reencrypt(implementations):
umbral1, umbral2 = implementations
metadata = b'metadata'
threshold = 2
num_frags = 3
plaintext = b'peace at dawn'
# On client 1
receiving_sk_bytes, receiving_pk_bytes = _create_keypair(umbral1)
delegating_sk_bytes, delegating_pk_bytes = _create_keypair(umbral1)
signing_sk_bytes, signing_pk_bytes = _create_keypair(umbral1)
capsule_bytes, ciphertext = _encrypt(umbral1, plaintext, delegating_pk_bytes)
kfrags_bytes = _generate_kfrags(umbral1, delegating_sk_bytes, receiving_pk_bytes,
signing_sk_bytes, threshold, num_frags)
# On client 2
cfrags_bytes = _reencrypt(umbral2, capsule_bytes, kfrags_bytes, threshold, metadata)
# On client 1
plaintext_reencrypted = _decrypt_reencrypted(umbral1,
receiving_sk_bytes, delegating_pk_bytes, signing_pk_bytes,
capsule_bytes, cfrags_bytes, ciphertext, metadata)
assert plaintext_reencrypted == plaintext

113
tests/test_curve.py Normal file
View File

@ -0,0 +1,113 @@
import pytest
from umbral.openssl import Curve, bn_to_int, point_to_affine_coords
from umbral.curve import CURVE, CURVES, SECP256R1, SECP256K1, SECP384R1
def test_supported_curves():
# Ensure we have the correct number of supported curves hardcoded
number_of_supported_curves = 3
assert len(Curve._supported_curves) == number_of_supported_curves
# Manually ensure the `_supported curves` dict contains only valid supported curves
assert Curve._supported_curves[415] == 'secp256r1'
assert Curve._supported_curves[714] == 'secp256k1'
assert Curve._supported_curves[715] == 'secp384r1'
def test_create_by_nid():
nid, name = 714, 'secp256k1'
# supported
_curve_714 = Curve(nid=nid)
assert _curve_714.nid == nid
assert _curve_714.name == name
# unsuported
with pytest.raises(NotImplementedError):
Curve(711)
def test_create_by_name():
nid, name = 714, 'secp256k1'
# Supported
_curve_secp256k1 = Curve.from_name(name)
assert _curve_secp256k1.name == name
assert _curve_secp256k1.nid == nid
# Unsupported
with pytest.raises(NotImplementedError):
Curve.from_name('abcd123e4')
def test_curve_constants():
test_p256 = SECP256R1
test_secp256k1 = SECP256K1
test_p384 = SECP384R1
assert CURVE == SECP256K1
# Test the hardcoded curve NIDs are correct:
assert test_p256.nid == 415
assert test_secp256k1.nid == 714
assert test_p384.nid == 715
# Ensure every curve constant is in the CURVES collection
number_of_supported_curves = 3
assert len(CURVES) == number_of_supported_curves
# Ensure all supported curves can be initialized
for nid, name in Curve._supported_curves.items():
by_nid, by_name = Curve(nid=nid), Curve.from_name(name)
assert by_nid.name == name
assert by_name.nid == nid
def test_curve_str():
for curve in CURVES:
s = str(curve)
assert str(curve.nid) in s
assert str(curve.name) in s
def _curve_info(curve: Curve):
assert bn_to_int(curve.bn_order) == curve.order
return dict(order=curve.order,
field_element_size=curve.field_element_size,
scalar_size=curve.scalar_size,
generator=point_to_affine_coords(curve, curve.point_generator))
def test_secp256k1():
info = _curve_info(SECP256K1)
assert info['order'] == 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_BAAEDCE6_AF48A03B_BFD25E8C_D0364141
assert info['field_element_size'] == 32
assert info['scalar_size'] == 32
assert info['generator'] == (
0x79BE667E_F9DCBBAC_55A06295_CE870B07_029BFCDB_2DCE28D9_59F2815B_16F81798,
0x483ADA77_26A3C465_5DA4FBFC_0E1108A8_FD17B448_A6855419_9C47D08F_FB10D4B8)
def test_p256():
info = _curve_info(SECP256R1)
assert info['order'] == 0xFFFFFFFF_00000000_FFFFFFFF_FFFFFFFF_BCE6FAAD_A7179E84_F3B9CAC2_FC632551
assert info['field_element_size'] == 32
assert info['scalar_size'] == 32
assert info['generator'] == (
0x6B17D1F2_E12C4247_F8BCE6E5_63A440F2_77037D81_2DEB33A0_F4A13945_D898C296,
0x4FE342E2_FE1A7F9B_8EE7EB4A_7C0F9E16_2BCE3357_6B315ECE_CBB64068_37BF51F5)
def test_p384():
info = _curve_info(SECP384R1)
assert info['order'] == 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_C7634D81_F4372DDF_581A0DB2_48B0A77A_ECEC196A_CCC52973
assert info['field_element_size'] == 48
assert info['scalar_size'] == 48
assert info['generator'] == (
0xAA87CA22_BE8B0537_8EB1C71E_F320AD74_6E1D3B62_8BA79B98_59F741E0_82542A38_5502F25D_BF55296C_3A545E38_72760AB7,
0x3617DE4A_96262C6F_5D9E98BF_9292DC29_F8F41DBD_289A147C_E9DA3113_B5F0B8C0_0A60B1CE_1D7E819D_7A431D7C_90EA0E5F)

90
tests/test_curve_point.py Normal file
View File

@ -0,0 +1,90 @@
import pytest
from umbral.openssl import ErrorInvalidCompressedPoint, ErrorInvalidPointEncoding
from umbral.curve_point import CurvePoint
from umbral.curve import CURVE
def test_random():
p1 = CurvePoint.random()
p2 = CurvePoint.random()
assert isinstance(p1, CurvePoint)
assert isinstance(p2, CurvePoint)
assert p1 != p2
def test_generator_point():
"""http://www.secg.org/SEC2-Ver-1.0.pdf Section 2.7.1"""
g1 = CurvePoint.generator()
g_compressed = 0x0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
g_compressed_bytes = g_compressed.to_bytes(CURVE.field_element_size + 1, byteorder='big')
g2 = CurvePoint.from_bytes(g_compressed_bytes)
assert g1 == g2
def test_to_and_from_affine():
x = 17004608369308732328368332205668001941491834793934321461466076545247324070015
y = 69725941631324401609944843130171147910924748427773762412028916504484868631573
p = CurvePoint.from_affine(x, y)
assert p.to_affine() == (x, y)
def test_invalid_serialized_points():
field_order = 2**256 - 0x1000003D1
# A point on secp256k1
x = 17004608369308732328368332205668001941491834793934321461466076545247324070015
y = 69725941631324401609944843130171147910924748427773762412028916504484868631573
# Check it
assert (y**2 - x**3 - 7) % field_order == 0
# Should load
point_data = b'\x03' + x.to_bytes(CURVE.field_element_size, 'big')
p = CurvePoint.from_bytes(point_data)
# Make it invalid
bad_x = x - 1
assert (y**2 - bad_x**3 - 7) % field_order != 0
bad_x_data = b'\x03' + bad_x.to_bytes(CURVE.field_element_size, 'big')
with pytest.raises(ErrorInvalidCompressedPoint):
CurvePoint.from_bytes(bad_x_data)
# Valid x, invalid prefix
bad_format = b'\xff' + x.to_bytes(CURVE.field_element_size, 'big')
with pytest.raises(ErrorInvalidPointEncoding):
CurvePoint.from_bytes(bad_format)
def test_serialize_point_at_infinity():
p = CurvePoint.random()
point_at_infinity = p - p
bytes_point_at_infinity = bytes(point_at_infinity)
assert bytes_point_at_infinity == b'\x00'
def test_coords_with_special_characteristics():
# Testing that a point with x coordinate greater than the curve order is still valid.
# In particular, we will test the last valid point from the default curve (secp256k1)
# whose x coordinate is `field_order - 3` and is greater than the order of the curve
field_order = 2**256 - 0x1000003D1
compressed = b'\x02' + (field_order-3).to_bytes(32, 'big')
last_point = CurvePoint.from_bytes(compressed)
# The same point, but obtained through the from_affine method
x = 115792089237316195423570985008687907853269984665640564039457584007908834671660
y = 109188863561374057667848968960504138135859662956057034999983532397866404169138
assert last_point == CurvePoint.from_affine(x, y)

127
tests/test_curve_scalar.py Normal file
View File

@ -0,0 +1,127 @@
import pytest
from umbral.curve import CURVE
from umbral.curve_scalar import CurveScalar
from umbral.hashing import Hash
def test_random():
r1 = CurveScalar.random_nonzero()
r2 = CurveScalar.random_nonzero()
assert r1 != r2
assert not r1.is_zero()
assert not r2.is_zero()
def test_from_and_to_int():
zero = CurveScalar.from_int(0)
assert zero.is_zero()
assert int(zero) == 0
one = CurveScalar.one()
assert not one.is_zero()
assert int(one) == 1
big_int = CURVE.order - 2
big_scalar = CurveScalar.from_int(big_int)
assert int(big_scalar) == big_int
# normalization check
with pytest.raises(ValueError):
CurveScalar.from_int(CURVE.order)
# disable normalization check
too_big = CurveScalar.from_int(CURVE.order, check_normalization=False)
def test_from_digest():
digest = Hash(b'asdf')
digest.update(b'some info')
s1 = CurveScalar.from_digest(digest)
digest = Hash(b'asdf')
digest.update(b'some info')
s2 = CurveScalar.from_digest(digest)
assert s1 == s2
assert int(s1) == int(s2)
def test_eq():
random = CurveScalar.random_nonzero()
same = CurveScalar.from_int(int(random))
different = CurveScalar.random_nonzero()
assert random == same
assert random == int(same)
assert random != different
assert random != int(different)
def test_serialization_rotations_of_1():
size_in_bytes = CURVE.scalar_size
for i in range(size_in_bytes):
lonely_one = 1 << i
bn = CurveScalar.from_int(lonely_one)
lonely_one_in_bytes = lonely_one.to_bytes(size_in_bytes, 'big')
# Check serialization
assert bytes(bn) == lonely_one_in_bytes
# Check deserialization
assert CurveScalar.from_bytes(lonely_one_in_bytes) == bn
def test_invalid_deserialization():
size_in_bytes = CURVE.scalar_size
# All-ones bytestring is invalid (since it's greater than the order)
lots_of_ones = b'\xFF' * size_in_bytes
with pytest.raises(ValueError):
CurveScalar.from_bytes(lots_of_ones)
# Serialization of `order` is invalid since it's not strictly lower than
# the order of the curve
order = CURVE.order
with pytest.raises(ValueError):
CurveScalar.from_bytes(order.to_bytes(size_in_bytes, 'big'))
# On the other hand, serialization of `order - 1` is valid
order -= 1
CurveScalar.from_bytes(order.to_bytes(size_in_bytes, 'big'))
def test_add():
r1 = CurveScalar.random_nonzero()
r2 = CurveScalar.random_nonzero()
r1i = int(r1)
r2i = int(r2)
assert r1 + r2 == (r1i + r2i) % CURVE.order
assert r1 + r2i == (r1i + r2i) % CURVE.order
def test_sub():
r1 = CurveScalar.random_nonzero()
r2 = CurveScalar.random_nonzero()
r1i = int(r1)
r2i = int(r2)
assert r1 - r2 == (r1i - r2i) % CURVE.order
assert r1 - r2i == (r1i - r2i) % CURVE.order
def test_mul():
r1 = CurveScalar.random_nonzero()
r2 = CurveScalar.random_nonzero()
r1i = int(r1)
r2i = int(r2)
assert r1 * r2 == (r1i * r2i) % CURVE.order
assert r1 * r2i == (r1i * r2i) % CURVE.order
def test_invert():
r1 = CurveScalar.random_nonzero()
r1i = int(r1)
r1inv = r1.invert()
assert r1 * r1inv == CurveScalar.one()
assert (r1i * int(r1inv)) % CURVE.order == 1

80
tests/test_dem.py Normal file
View File

@ -0,0 +1,80 @@
import pytest
import os
from umbral.dem import DEM, ErrorInvalidTag
def test_encrypt_decrypt():
key = os.urandom(DEM.KEY_SIZE)
dem = DEM(key)
plaintext = b'peace at dawn'
ciphertext0 = dem.encrypt(plaintext)
ciphertext1 = dem.encrypt(plaintext)
assert ciphertext0 != plaintext
assert ciphertext1 != plaintext
# Ciphertext should be different even with same plaintext.
assert ciphertext0 != ciphertext1
# Nonce should be different
assert ciphertext0[:DEM.NONCE_SIZE] != ciphertext1[:DEM.NONCE_SIZE]
cleartext0 = dem.decrypt(ciphertext0)
cleartext1 = dem.decrypt(ciphertext1)
assert cleartext0 == plaintext
assert cleartext1 == plaintext
def test_malformed_ciphertext():
key = os.urandom(DEM.KEY_SIZE)
dem = DEM(key)
plaintext = b'peace at dawn'
ciphertext = dem.encrypt(plaintext)
# So short it we can tell right away it doesn't even contain a nonce
with pytest.raises(ValueError, match="The ciphertext must include the nonce"):
dem.decrypt(ciphertext[:DEM.NONCE_SIZE-1])
# Too short to contain a tag
with pytest.raises(ValueError, match="The authentication tag is missing or malformed"):
dem.decrypt(ciphertext[:DEM.NONCE_SIZE + DEM.TAG_SIZE - 1])
# Too long
with pytest.raises(ErrorInvalidTag):
dem.decrypt(ciphertext + b'abcd')
def test_encrypt_decrypt_associated_data():
key = os.urandom(32)
aad = b'secret code 1234'
dem = DEM(key)
plaintext = b'peace at dawn'
ciphertext0 = dem.encrypt(plaintext, authenticated_data=aad)
ciphertext1 = dem.encrypt(plaintext, authenticated_data=aad)
assert ciphertext0 != plaintext
assert ciphertext1 != plaintext
assert ciphertext0 != ciphertext1
assert ciphertext0[:DEM.NONCE_SIZE] != ciphertext1[:DEM.NONCE_SIZE]
cleartext0 = dem.decrypt(ciphertext0, authenticated_data=aad)
cleartext1 = dem.decrypt(ciphertext1, authenticated_data=aad)
assert cleartext0 == plaintext
assert cleartext1 == plaintext
# Attempt decryption with invalid associated data
with pytest.raises(ErrorInvalidTag):
cleartext2 = dem.decrypt(ciphertext0, authenticated_data=b'wrong data')

126
tests/test_key_frag.py Normal file
View File

@ -0,0 +1,126 @@
import pytest
from umbral import KeyFrag, PublicKey, generate_kfrags
from umbral.key_frag import KeyFragID
from umbral.curve_scalar import CurveScalar
def test_kfrag_serialization(alices_keys, bobs_keys, kfrags):
delegating_sk, signing_sk = alices_keys
_receiving_sk, receiving_pk = bobs_keys
signing_pk = PublicKey.from_secret_key(signing_sk)
delegating_pk = PublicKey.from_secret_key(delegating_sk)
for kfrag in kfrags:
kfrag_bytes = bytes(kfrag)
new_kfrag = KeyFrag.from_bytes(kfrag_bytes)
assert new_kfrag.verify(signing_pk=signing_pk,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk)
assert new_kfrag == kfrag
def test_kfrag_verification(alices_keys, bobs_keys, kfrags):
delegating_sk, signing_sk = alices_keys
_receiving_sk, receiving_pk = bobs_keys
signing_pk = PublicKey.from_secret_key(signing_sk)
delegating_pk = PublicKey.from_secret_key(delegating_sk)
# Wrong signature
kfrag = kfrags[0]
kfrag.id = KeyFragID.random()
kfrag_bytes = bytes(kfrag)
new_kfrag = KeyFrag.from_bytes(kfrag_bytes)
assert not new_kfrag.verify(signing_pk=signing_pk,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk)
# Wrong key
kfrag = kfrags[1]
kfrag.key = CurveScalar.random_nonzero()
kfrag_bytes = bytes(kfrag)
new_kfrag = KeyFrag.from_bytes(kfrag_bytes)
assert not new_kfrag.verify(signing_pk=signing_pk,
delegating_pk=delegating_pk,
receiving_pk=receiving_pk)
@pytest.mark.parametrize('sign_delegating_key',
[False, True],
ids=['sign_delegating_key', 'dont_sign_delegating_key'])
@pytest.mark.parametrize('sign_receiving_key',
[False, True],
ids=['sign_receiving_key', 'dont_sign_receiving_key'])
def test_kfrag_signing(alices_keys, bobs_keys, sign_delegating_key, sign_receiving_key):
delegating_sk, signing_sk = alices_keys
_receiving_sk, receiving_pk = bobs_keys
signing_pk = PublicKey.from_secret_key(signing_sk)
delegating_pk = PublicKey.from_secret_key(delegating_sk)
kfrags = generate_kfrags(delegating_sk=delegating_sk,
signing_sk=signing_sk,
receiving_pk=receiving_pk,
threshold=6,
num_kfrags=10,
sign_delegating_key=sign_delegating_key,
sign_receiving_key=sign_receiving_key)
kfrag = kfrags[0]
# serialize/deserialize to make sure sign_* fields are serialized correctly
kfrag = KeyFrag.from_bytes(bytes(kfrag))
for pass_delegating_key, pass_receiving_key in zip([False, True], [False, True]):
delegating_key_ok = (not sign_delegating_key) or pass_delegating_key
receiving_key_ok = (not sign_receiving_key) or pass_receiving_key
should_verify = delegating_key_ok and receiving_key_ok
result = kfrag.verify(signing_pk=signing_pk,
delegating_pk=delegating_pk if pass_delegating_key else None,
receiving_pk=receiving_pk if pass_receiving_key else None)
assert result == should_verify
def test_kfrag_is_hashable(kfrags):
assert hash(kfrags[0]) != hash(kfrags[1])
new_kfrag = KeyFrag.from_bytes(bytes(kfrags[0]))
assert hash(new_kfrag) == hash(kfrags[0])
def test_kfrag_str(kfrags):
s = str(kfrags[0])
assert "KeyFrag" in s
WRONG_PARAMETERS = (
# (num_kfrags, threshold)
(-1, -1), (-1, 0), (-1, 5),
(0, -1), (0, 0), (0, 5),
(1, -1), (1, 0), (1, 5),
(5, -1), (5, 0), (5, 10)
)
@pytest.mark.parametrize("num_kfrags, threshold", WRONG_PARAMETERS)
def test_wrong_threshold_and_num_kfrags(num_kfrags, threshold, alices_keys, bobs_keys):
delegating_sk, signing_sk = alices_keys
_receiving_sk, receiving_pk = bobs_keys
with pytest.raises(ValueError):
generate_kfrags(delegating_sk=delegating_sk,
signing_sk=signing_sk,
receiving_pk=receiving_pk,
threshold=threshold,
num_kfrags=num_kfrags)

203
tests/test_keys.py Normal file
View File

@ -0,0 +1,203 @@
import os
import string
import pytest
from umbral.keys import PublicKey, SecretKey, SecretKeyFactory, Signature
from umbral.hashing import Hash
def test_gen_key():
sk = SecretKey.random()
assert type(sk) == SecretKey
pk = PublicKey.from_secret_key(sk)
assert type(pk) == PublicKey
pk2 = PublicKey.from_secret_key(sk)
assert pk == pk2
def test_derive_key_from_label():
factory = SecretKeyFactory.random()
label = b"my_healthcare_information"
sk1 = factory.secret_key_by_label(label)
assert type(sk1) == SecretKey
pk1 = PublicKey.from_secret_key(sk1)
assert type(pk1) == PublicKey
# Check that key derivation is reproducible
sk2 = factory.secret_key_by_label(label)
pk2 = PublicKey.from_secret_key(sk2)
assert sk1 == sk2
assert pk1 == pk2
# Different labels on the same master secret create different keys
label = b"my_tax_information"
sk3 = factory.secret_key_by_label(label)
pk3 = PublicKey.from_secret_key(sk3)
assert sk1 != sk3
def test_secret_key_serialization():
sk = SecretKey.random()
encoded_key = bytes(sk)
decoded_key = SecretKey.from_bytes(encoded_key)
assert sk == decoded_key
def test_secret_key_str():
sk = SecretKey.random()
s = str(sk)
assert s == "SecretKey:..."
def test_secret_key_hash():
sk = SecretKey.random()
# Insecure Python hash, shouldn't be available.
with pytest.raises(NotImplementedError):
hash(sk)
def test_secret_key_factory_str():
skf = SecretKeyFactory.random()
s = str(skf)
assert s == "SecretKeyFactory:..."
def test_secret_key_factory_hash():
skf = SecretKeyFactory.random()
# Insecure Python hash, shouldn't be available.
with pytest.raises(NotImplementedError):
hash(skf)
def test_public_key_serialization():
sk = SecretKey.random()
pk = PublicKey.from_secret_key(sk)
encoded_key = bytes(pk)
decoded_key = PublicKey.from_bytes(encoded_key)
assert pk == decoded_key
def test_public_key_point():
pk = PublicKey.from_secret_key(SecretKey.random())
assert bytes(pk) == bytes(pk.point())
def test_public_key_str():
pk = PublicKey.from_secret_key(SecretKey.random())
s = str(pk)
assert 'PublicKey' in s
def test_keying_material_serialization():
factory = SecretKeyFactory.random()
encoded_factory = bytes(factory)
decoded_factory = SecretKeyFactory.from_bytes(encoded_factory)
label = os.urandom(32)
sk1 = factory.secret_key_by_label(label)
sk2 = decoded_factory.secret_key_by_label(label)
assert sk1 == sk2
def test_public_key_is_hashable():
sk = SecretKey.random()
pk = PublicKey.from_secret_key(sk)
sk2 = SecretKey.random()
pk2 = PublicKey.from_secret_key(sk2)
assert hash(pk) != hash(pk2)
pk3 = PublicKey.from_bytes(bytes(pk))
assert hash(pk) == hash(pk3)
@pytest.mark.parametrize('execution_number', range(20)) # Run this test 20 times.
def test_sign_and_verify(execution_number):
sk = SecretKey.random()
pk = PublicKey.from_secret_key(sk)
message = b"peace at dawn"
dst = b"dst"
digest = Hash(dst)
digest.update(message)
signature = sk.sign_digest(digest)
digest = Hash(dst)
digest.update(message)
assert signature.verify_digest(pk, digest)
@pytest.mark.parametrize('execution_number', range(20)) # Run this test 20 times.
def test_sign_serialize_and_verify(execution_number):
sk = SecretKey.random()
pk = PublicKey.from_secret_key(sk)
message = b"peace at dawn"
dst = b"dst"
digest = Hash(dst)
digest.update(message)
signature = sk.sign_digest(digest)
signature_bytes = bytes(signature)
signature_restored = Signature.from_bytes(signature_bytes)
digest = Hash(dst)
digest.update(message)
assert signature_restored.verify_digest(pk, digest)
def test_verification_fail():
sk = SecretKey.random()
pk = PublicKey.from_secret_key(sk)
message = b"peace at dawn"
dst = b"dst"
digest = Hash(dst)
digest.update(message)
signature = sk.sign_digest(digest)
# wrong DST
digest = Hash(b"other dst")
digest.update(message)
assert not signature.verify_digest(pk, digest)
# wrong message
digest = Hash(dst)
digest.update(b"no peace at dawn")
assert not signature.verify_digest(pk, digest)
# bad signature
signature_bytes = bytes(signature)
signature_bytes = b'\x00' + signature_bytes[1:]
signature_restored = Signature.from_bytes(signature_bytes)
digest = Hash(dst)
digest.update(message)
assert not signature_restored.verify_digest(pk, digest)
def test_signature_repr():
sk = SecretKey.random()
pk = PublicKey.from_secret_key(sk)
message = b"peace at dawn"
dst = b"dst"
digest = Hash(dst)
digest.update(message)
signature = sk.sign_digest(digest)
s = repr(signature)
assert 'Signature' in s

95
tests/test_pre.py Normal file
View File

@ -0,0 +1,95 @@
import pytest
from umbral import (
SecretKey,
PublicKey,
encrypt,
generate_kfrags,
decrypt_original,
reencrypt,
decrypt_reencrypted,
)
from umbral.dem import ErrorInvalidTag
def test_public_key_encryption(alices_keys):
delegating_sk, _ = alices_keys
delegating_pk = PublicKey.from_secret_key(delegating_sk)
plaintext = b'peace at dawn'
capsule, ciphertext = encrypt(delegating_pk, plaintext)
plaintext_decrypted = decrypt_original(delegating_sk, capsule, ciphertext)
assert plaintext == plaintext_decrypted
# Wrong secret key
sk = SecretKey.random()
with pytest.raises(ErrorInvalidTag):
decrypt_original(sk, capsule, ciphertext)
SIMPLE_API_PARAMETERS = (
# (num_kfrags, threshold)
(1, 1),
(6, 1),
(6, 4),
(6, 6),
(50, 30)
)
@pytest.mark.parametrize("num_kfrags, threshold", SIMPLE_API_PARAMETERS)
def test_simple_api(num_kfrags, threshold):
"""
This test models the main interactions between actors (i.e., Alice,
Bob, Data Source, and Ursulas) and artifacts (i.e., public and private keys,
ciphertexts, capsules, KFrags, CFrags, etc).
The test covers all the main stages of data sharing:
key generation, delegation, encryption, decryption by
Alice, re-encryption by Ursula, and decryption by Bob.
"""
# Key Generation (Alice)
delegating_sk = SecretKey.random()
delegating_pk = PublicKey.from_secret_key(delegating_sk)
signing_sk = SecretKey.random()
signing_pk = PublicKey.from_secret_key(signing_sk)
# Key Generation (Bob)
receiving_sk = SecretKey.random()
receiving_pk = PublicKey.from_secret_key(receiving_sk)
# Encryption by an unnamed data source
plaintext = b'peace at dawn'
capsule, ciphertext = encrypt(delegating_pk, plaintext)
# Decryption by Alice
plaintext_decrypted = decrypt_original(delegating_sk, capsule, ciphertext)
assert plaintext_decrypted == plaintext
# Split Re-Encryption Key Generation (aka Delegation)
kfrags = generate_kfrags(delegating_sk, receiving_pk, signing_sk, threshold, num_kfrags)
# Bob requests re-encryption to some set of M ursulas
cfrags = list()
for kfrag in kfrags[:threshold]:
# Ursula checks that the received kfrag is valid
assert kfrag.verify(signing_pk, delegating_pk, receiving_pk)
# Re-encryption by an Ursula
cfrag = reencrypt(capsule, kfrag)
# Bob collects the result
cfrags.append(cfrag)
# Bob checks that the received cfrags are valid
assert all(cfrag.verify(capsule, delegating_pk, receiving_pk, signing_pk) for cfrag in cfrags)
# Decryption by Bob
plaintext_reenc = decrypt_reencrypted(receiving_sk,
delegating_pk,
capsule,
cfrags[:threshold],
ciphertext,
)
assert plaintext_reenc == plaintext

View File

@ -0,0 +1,92 @@
import re
import pytest
from umbral.serializable import Serializable, serialize_bool, take_bool
class A(Serializable):
def __init__(self, val: int):
assert 0 <= val < 2**32
self.val = val
@classmethod
def __take__(cls, data):
val_bytes, data = cls.__take_bytes__(data, 4)
return cls(int.from_bytes(val_bytes, byteorder='big')), data
def __bytes__(self):
return self.val.to_bytes(4, byteorder='big')
def __eq__(self, other):
return isinstance(other, A) and self.val == other.val
class B(Serializable):
def __init__(self, val: int):
assert 0 <= val < 2**16
self.val = val
@classmethod
def __take__(cls, data):
val_bytes, data = cls.__take_bytes__(data, 2)
return cls(int.from_bytes(val_bytes, byteorder='big')), data
def __bytes__(self):
return self.val.to_bytes(2, byteorder='big')
def __eq__(self, other):
return isinstance(other, B) and self.val == other.val
class C(Serializable):
def __init__(self, a: A, b: B):
self.a = a
self.b = b
@classmethod
def __take__(cls, data):
components, data = cls.__take_types__(data, A, B)
return cls(*components), data
def __bytes__(self):
return bytes(self.a) + bytes(self.b)
def __eq__(self, other):
return isinstance(other, C) and self.a == other.a and self.b == other.b
def test_normal_operation():
a = A(2**32 - 123)
b = B(2**16 - 456)
c = C(a, b)
c_back = C.from_bytes(bytes(c))
assert c_back == c
def test_too_many_bytes():
a = A(2**32 - 123)
b = B(2**16 - 456)
c = C(a, b)
with pytest.raises(ValueError, match="1 bytes remaining after deserializing"):
C.from_bytes(bytes(c) + b'\x00')
def test_not_enough_bytes():
a = A(2**32 - 123)
b = B(2**16 - 456)
c = C(a, b)
# Will happen on deserialization of B - 1 byte missing
with pytest.raises(ValueError, match="cannot take 2 bytes from a bytestring of size 1"):
C.from_bytes(bytes(c)[:-1])
def test_serialize_bool():
assert take_bool(serialize_bool(True) + b'1234') == (True, b'1234')
assert take_bool(serialize_bool(False) + b'12') == (False, b'12')
error_msg = re.escape("Incorrectly serialized boolean; expected b'\\x00' or b'\\x01', got b'z'")
with pytest.raises(ValueError, match=error_msg):
take_bool(b'z1234')

View File

@ -4,11 +4,13 @@ from typing import Optional
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
import nacl
from nacl.bindings.crypto_aead import (
crypto_aead_xchacha20poly1305_ietf_encrypt as xchacha_encrypt,
crypto_aead_xchacha20poly1305_ietf_decrypt as xchacha_decrypt,
crypto_aead_xchacha20poly1305_ietf_KEYBYTES as XCHACHA_KEY_SIZE,
crypto_aead_xchacha20poly1305_ietf_NPUBBYTES as XCHACHA_NONCE_SIZE,
crypto_aead_xchacha20poly1305_ietf_ABYTES as XCHACHA_TAG_SIZE,
)
from . import openssl
@ -28,10 +30,15 @@ def kdf(data: bytes,
return hkdf.derive(data)
class ErrorInvalidTag(Exception):
pass
class DEM:
KEY_SIZE = XCHACHA_KEY_SIZE
NONCE_SIZE = XCHACHA_NONCE_SIZE
TAG_SIZE = XCHACHA_TAG_SIZE
def __init__(self,
key_material: bytes,
@ -53,5 +60,11 @@ class DEM:
nonce = nonce_and_ciphertext[:self.NONCE_SIZE]
ciphertext = nonce_and_ciphertext[self.NONCE_SIZE:]
# TODO: replace `nacl.exceptions.CryptoError` with our error?
return xchacha_decrypt(ciphertext, authenticated_data, nonce, self._key)
# Prevent an out of bounds error deep in NaCl
if len(ciphertext) < self.TAG_SIZE:
raise ValueError(f"The authentication tag is missing or malformed")
try:
return xchacha_decrypt(ciphertext, authenticated_data, nonce, self._key)
except nacl.exceptions.CryptoError:
raise ErrorInvalidTag