mirror of https://github.com/nucypher/pyUmbral.git
Some type annotation improvements and other minor change requests
parent
7d9ddfe6e8
commit
87b24a0083
|
@ -203,7 +203,7 @@ def test_umbral_public_key_as_dict_key():
|
||||||
another_umbral_pub_key = another_umbral_priv_key.get_pubkey()
|
another_umbral_pub_key = another_umbral_priv_key.get_pubkey()
|
||||||
|
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
d[another_umbral_pub_key]
|
_ = d[another_umbral_pub_key]
|
||||||
|
|
||||||
d[another_umbral_pub_key] = False
|
d[another_umbral_pub_key] = False
|
||||||
|
|
||||||
|
|
|
@ -55,9 +55,9 @@ def prove_cfrag_correctness(cfrag: 'CapsuleFrag',
|
||||||
v2 = t * v
|
v2 = t * v
|
||||||
u2 = t * u
|
u2 = t * u
|
||||||
|
|
||||||
hash_input = (e, e1, e2, v, v1, v2, u, u1, u2)
|
hash_input = [e, e1, e2, v, v1, v2, u, u1, u2]
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
hash_input += (metadata,)
|
hash_input.append(metadata)
|
||||||
h = CurveBN.hash(*hash_input, params=params)
|
h = CurveBN.hash(*hash_input, params=params)
|
||||||
########
|
########
|
||||||
|
|
||||||
|
@ -98,9 +98,9 @@ def assess_cfrag_correctness(cfrag: 'CapsuleFrag', capsule: 'Capsule') -> bool:
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
hash_input = (e, e1, e2, v, v1, v2, u, u1, u2)
|
hash_input = [e, e1, e2, v, v1, v2, u, u1, u2]
|
||||||
if cfrag.proof.metadata is not None:
|
if cfrag.proof.metadata is not None:
|
||||||
hash_input += (cfrag.proof.metadata,)
|
hash_input.append(cfrag.proof.metadata)
|
||||||
h = CurveBN.hash(*hash_input, params=params)
|
h = CurveBN.hash(*hash_input, params=params)
|
||||||
########
|
########
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,11 @@ class _CONFIG:
|
||||||
__curve = None
|
__curve = None
|
||||||
__params = None
|
__params = None
|
||||||
__CURVE_TO_USE_IF_NO_DEFAULT_IS_SET_BY_USER = SECP256K1
|
__CURVE_TO_USE_IF_NO_DEFAULT_IS_SET_BY_USER = SECP256K1
|
||||||
__WARNING_IF_NO_DEFAULT_SET = "No default curve has been set. Using SECP256K1. A slight performance penalty has been incurred for only this call. Set a default curve with umbral.config.set_default_curve()."
|
__WARNING_IF_NO_DEFAULT_SET = "No default curve has been set. " \
|
||||||
|
"Using SECP256K1. " \
|
||||||
|
"A slight performance penalty has been " \
|
||||||
|
"incurred for only this call. Set a default " \
|
||||||
|
"curve with umbral.config.set_default_curve()."
|
||||||
|
|
||||||
class UmbralConfigurationError(RuntimeError):
|
class UmbralConfigurationError(RuntimeError):
|
||||||
"""Raised when somebody does something dumb re: configuration."""
|
"""Raised when somebody does something dumb re: configuration."""
|
||||||
|
@ -42,13 +46,13 @@ class _CONFIG:
|
||||||
def params(cls) -> UmbralParameters:
|
def params(cls) -> UmbralParameters:
|
||||||
if not cls.__params:
|
if not cls.__params:
|
||||||
cls.__set_curve_by_default()
|
cls.__set_curve_by_default()
|
||||||
return cls.__params
|
return cls.__params # type: ignore
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def curve(cls) -> Curve:
|
def curve(cls) -> Curve:
|
||||||
if not cls.__curve:
|
if not cls.__curve:
|
||||||
cls.__set_curve_by_default()
|
cls.__set_curve_by_default()
|
||||||
return cls.__curve
|
return cls.__curve # type: ignore
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_curve(cls, curve: Optional[Curve] = None) -> None:
|
def set_curve(cls, curve: Optional[Curve] = None) -> None:
|
||||||
|
|
|
@ -56,8 +56,8 @@ class Curve:
|
||||||
self.__generator = openssl._get_ec_generator_by_group(self.ec_group)
|
self.__generator = openssl._get_ec_generator_by_group(self.ec_group)
|
||||||
|
|
||||||
# Init cache
|
# Init cache
|
||||||
self.__field_order_size_in_bytes = None
|
self.__field_order_size_in_bytes = 0
|
||||||
self.__group_order_size_in_bytes = None
|
self.__group_order_size_in_bytes = 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_name(cls, name: str) -> 'Curve':
|
def from_name(cls, name: str) -> 'Curve':
|
||||||
|
@ -93,15 +93,14 @@ class Curve:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def field_order_size_in_bytes(self) -> int:
|
def field_order_size_in_bytes(self) -> int:
|
||||||
if self.__field_order_size_in_bytes is None:
|
if not self.__field_order_size_in_bytes:
|
||||||
backend = default_backend()
|
|
||||||
size_in_bits = openssl._get_ec_group_degree(self.__ec_group)
|
size_in_bits = openssl._get_ec_group_degree(self.__ec_group)
|
||||||
self.__field_order_size_in_bytes = (size_in_bits + 7) // 8
|
self.__field_order_size_in_bytes = (size_in_bits + 7) // 8
|
||||||
return self.__field_order_size_in_bytes
|
return self.__field_order_size_in_bytes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def group_order_size_in_bytes(self) -> int:
|
def group_order_size_in_bytes(self) -> int:
|
||||||
if self.__group_order_size_in_bytes is None:
|
if not self.__group_order_size_in_bytes:
|
||||||
BN_num_bytes = default_backend()._lib.BN_num_bytes
|
BN_num_bytes = default_backend()._lib.BN_num_bytes
|
||||||
self.__group_order_size_in_bytes = BN_num_bytes(self.order)
|
self.__group_order_size_in_bytes = BN_num_bytes(self.order)
|
||||||
return self.__group_order_size_in_bytes
|
return self.__group_order_size_in_bytes
|
||||||
|
|
|
@ -17,7 +17,7 @@ You should have received a copy of the GNU General Public License
|
||||||
along with pyUmbral. If not, see <https://www.gnu.org/licenses/>.
|
along with pyUmbral. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
from cryptography.hazmat.backends.openssl import backend
|
from cryptography.hazmat.backends.openssl import backend
|
||||||
from cryptography.hazmat.primitives import hashes
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
@ -36,7 +36,7 @@ class CurveBN(object):
|
||||||
constant time operations.
|
constant time operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bignum, curve: Curve):
|
def __init__(self, bignum, curve: Curve) -> None:
|
||||||
on_curve = openssl._bn_is_on_curve(bignum, curve)
|
on_curve = openssl._bn_is_on_curve(bignum, curve)
|
||||||
if not on_curve:
|
if not on_curve:
|
||||||
raise ValueError("The provided BIGNUM is not on the provided curve.")
|
raise ValueError("The provided BIGNUM is not on the provided curve.")
|
||||||
|
@ -146,7 +146,7 @@ class CurveBN(object):
|
||||||
"""
|
"""
|
||||||
return backend._bn_to_int(self.bignum)
|
return backend._bn_to_int(self.bignum)
|
||||||
|
|
||||||
def __eq__(self, other : Union[int, 'CurveBN']) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
"""
|
"""
|
||||||
Compares the two BIGNUMS or int.
|
Compares the two BIGNUMS or int.
|
||||||
"""
|
"""
|
||||||
|
@ -169,6 +169,8 @@ class CurveBN(object):
|
||||||
other = openssl._int_to_bn(other)
|
other = openssl._int_to_bn(other)
|
||||||
other = CurveBN(other, self.curve)
|
other = CurveBN(other, self.curve)
|
||||||
|
|
||||||
|
other = cast('CurveBN', other) # This is just for mypy
|
||||||
|
|
||||||
power = openssl._get_new_BN()
|
power = openssl._get_new_BN()
|
||||||
with backend._tmp_bn_ctx() as bn_ctx, openssl._tmp_bn_mont_ctx(self.curve.order) as bn_mont_ctx:
|
with backend._tmp_bn_ctx() as bn_ctx, openssl._tmp_bn_mont_ctx(self.curve.order) as bn_mont_ctx:
|
||||||
res = backend._lib.BN_mod_exp_mont(
|
res = backend._lib.BN_mod_exp_mont(
|
||||||
|
@ -215,7 +217,6 @@ class CurveBN(object):
|
||||||
|
|
||||||
return CurveBN(product, self.curve)
|
return CurveBN(product, self.curve)
|
||||||
|
|
||||||
|
|
||||||
def __add__(self, other : Union[int, 'CurveBN']) -> 'CurveBN':
|
def __add__(self, other : Union[int, 'CurveBN']) -> 'CurveBN':
|
||||||
"""
|
"""
|
||||||
Performs a BN_mod_add on two BIGNUMs.
|
Performs a BN_mod_add on two BIGNUMs.
|
||||||
|
@ -224,6 +225,8 @@ class CurveBN(object):
|
||||||
other = openssl._int_to_bn(other)
|
other = openssl._int_to_bn(other)
|
||||||
other = CurveBN(other, self.curve)
|
other = CurveBN(other, self.curve)
|
||||||
|
|
||||||
|
other = cast('CurveBN', other) # This is just for mypy
|
||||||
|
|
||||||
op_sum = openssl._get_new_BN()
|
op_sum = openssl._get_new_BN()
|
||||||
with backend._tmp_bn_ctx() as bn_ctx:
|
with backend._tmp_bn_ctx() as bn_ctx:
|
||||||
res = backend._lib.BN_mod_add(
|
res = backend._lib.BN_mod_add(
|
||||||
|
@ -241,6 +244,8 @@ class CurveBN(object):
|
||||||
other = openssl._int_to_bn(other)
|
other = openssl._int_to_bn(other)
|
||||||
other = CurveBN(other, self.curve)
|
other = CurveBN(other, self.curve)
|
||||||
|
|
||||||
|
other = cast('CurveBN', other) # This is just for mypy
|
||||||
|
|
||||||
diff = openssl._get_new_BN()
|
diff = openssl._get_new_BN()
|
||||||
with backend._tmp_bn_ctx() as bn_ctx:
|
with backend._tmp_bn_ctx() as bn_ctx:
|
||||||
res = backend._lib.BN_mod_sub(
|
res = backend._lib.BN_mod_sub(
|
||||||
|
@ -291,6 +296,8 @@ class CurveBN(object):
|
||||||
other = openssl._int_to_bn(other)
|
other = openssl._int_to_bn(other)
|
||||||
other = CurveBN(other, self.curve)
|
other = CurveBN(other, self.curve)
|
||||||
|
|
||||||
|
other = cast('CurveBN', other) # This is just for mypy
|
||||||
|
|
||||||
rem = openssl._get_new_BN()
|
rem = openssl._get_new_BN()
|
||||||
with backend._tmp_bn_ctx() as bn_ctx:
|
with backend._tmp_bn_ctx() as bn_ctx:
|
||||||
res = backend._lib.BN_nnmod(
|
res = backend._lib.BN_nnmod(
|
||||||
|
|
|
@ -21,7 +21,6 @@ import hmac
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from bytestring_splitter import BytestringSplitter
|
from bytestring_splitter import BytestringSplitter
|
||||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
|
|
||||||
|
|
||||||
from umbral._pre import assess_cfrag_correctness, verify_kfrag
|
from umbral._pre import assess_cfrag_correctness, verify_kfrag
|
||||||
from umbral.config import default_curve, default_params
|
from umbral.config import default_curve, default_params
|
||||||
|
@ -30,6 +29,7 @@ from umbral.keys import UmbralPublicKey
|
||||||
from umbral.point import Point
|
from umbral.point import Point
|
||||||
from umbral.signing import Signature
|
from umbral.signing import Signature
|
||||||
from umbral.params import UmbralParameters
|
from umbral.params import UmbralParameters
|
||||||
|
from umbral.curve import Curve
|
||||||
|
|
||||||
from constant_sorrow.constants import NO_KEY, DELEGATING_ONLY, RECEIVING_ONLY, DELEGATING_AND_RECEIVING
|
from constant_sorrow.constants import NO_KEY, DELEGATING_ONLY, RECEIVING_ONLY, DELEGATING_AND_RECEIVING
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ class KFrag(object):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def expected_bytes_length(cls, curve: Optional[EllipticCurve] = None) -> int:
|
def expected_bytes_length(cls, curve: Optional[Curve] = None) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the size (in bytes) of a KFrag given the curve.
|
Returns the size (in bytes) of a KFrag given the curve.
|
||||||
If no curve is provided, it will use the default curve.
|
If no curve is provided, it will use the default curve.
|
||||||
|
@ -84,7 +84,7 @@ class KFrag(object):
|
||||||
return bn_size * 6 + point_size * 2 + 1
|
return bn_size * 6 + point_size * 2 + 1
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, data: bytes, curve: Optional[EllipticCurve] = None) -> 'KFrag':
|
def from_bytes(cls, data: bytes, curve: Optional[Curve] = None) -> 'KFrag':
|
||||||
"""
|
"""
|
||||||
Instantiate a KFrag object from the serialized data.
|
Instantiate a KFrag object from the serialized data.
|
||||||
"""
|
"""
|
||||||
|
@ -184,7 +184,7 @@ class CorrectnessProof(object):
|
||||||
self.kfrag_signature = kfrag_signature
|
self.kfrag_signature = kfrag_signature
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def expected_bytes_length(cls, curve: Optional[EllipticCurve] = None):
|
def expected_bytes_length(cls, curve: Optional[Curve] = None):
|
||||||
"""
|
"""
|
||||||
Returns the size (in bytes) of a CorrectnessProof without the metadata.
|
Returns the size (in bytes) of a CorrectnessProof without the metadata.
|
||||||
If no curve is given, it will use the default curve.
|
If no curve is given, it will use the default curve.
|
||||||
|
@ -196,7 +196,7 @@ class CorrectnessProof(object):
|
||||||
return (bn_size * 3) + (point_size * 4)
|
return (bn_size * 3) + (point_size * 4)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, data: bytes, curve: Optional[EllipticCurve] = None) -> 'CorrectnessProof':
|
def from_bytes(cls, data: bytes, curve: Optional[Curve] = None) -> 'CorrectnessProof':
|
||||||
"""
|
"""
|
||||||
Instantiate CorrectnessProof from serialized data.
|
Instantiate CorrectnessProof from serialized data.
|
||||||
"""
|
"""
|
||||||
|
@ -213,9 +213,9 @@ class CorrectnessProof(object):
|
||||||
(Signature, Signature.expected_bytes_length(curve), arguments), # kfrag_signature
|
(Signature, Signature.expected_bytes_length(curve), arguments), # kfrag_signature
|
||||||
)
|
)
|
||||||
components = splitter(data, return_remainder=True)
|
components = splitter(data, return_remainder=True)
|
||||||
metadata = components.pop(-1) or None
|
components.append(components.pop() or None)
|
||||||
|
|
||||||
return cls(*components, metadata=metadata)
|
return cls(*components)
|
||||||
|
|
||||||
def to_bytes(self) -> bytes:
|
def to_bytes(self) -> bytes:
|
||||||
"""
|
"""
|
||||||
|
@ -260,7 +260,7 @@ class CapsuleFrag(object):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def expected_bytes_length(cls, curve: Optional[EllipticCurve] = None) -> int:
|
def expected_bytes_length(cls, curve: Optional[Curve] = None) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the size (in bytes) of a CapsuleFrag given the curve without
|
Returns the size (in bytes) of a CapsuleFrag given the curve without
|
||||||
the CorrectnessProof.
|
the CorrectnessProof.
|
||||||
|
@ -273,7 +273,7 @@ class CapsuleFrag(object):
|
||||||
return (bn_size * 1) + (point_size * 3)
|
return (bn_size * 1) + (point_size * 3)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, data: bytes, curve: Optional[EllipticCurve] = None) -> 'CapsuleFrag':
|
def from_bytes(cls, data: bytes, curve: Optional[Curve] = None) -> 'CapsuleFrag':
|
||||||
"""
|
"""
|
||||||
Instantiates a CapsuleFrag object from the serialized data.
|
Instantiates a CapsuleFrag object from the serialized data.
|
||||||
"""
|
"""
|
||||||
|
@ -291,9 +291,10 @@ class CapsuleFrag(object):
|
||||||
)
|
)
|
||||||
components = splitter(data, return_remainder=True)
|
components = splitter(data, return_remainder=True)
|
||||||
|
|
||||||
proof = components.pop(-1) or None
|
proof = components.pop() or None
|
||||||
proof = CorrectnessProof.from_bytes(proof, curve) if proof else None
|
components.append(CorrectnessProof.from_bytes(proof, curve) if proof else None)
|
||||||
return cls(*components, proof=proof)
|
|
||||||
|
return cls(*components)
|
||||||
|
|
||||||
def to_bytes(self) -> bytes:
|
def to_bytes(self) -> bytes:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -18,7 +18,7 @@ along with pyUmbral. If not, see <https://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union, Any
|
||||||
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.backends.openssl.ec import _EllipticCurvePrivateKey, _EllipticCurvePublicKey
|
from cryptography.hazmat.backends.openssl.ec import _EllipticCurvePrivateKey, _EllipticCurvePublicKey
|
||||||
|
@ -42,7 +42,7 @@ class UmbralPrivateKey(object):
|
||||||
"""
|
"""
|
||||||
self.params = params
|
self.params = params
|
||||||
self.bn_key = bn_key
|
self.bn_key = bn_key
|
||||||
self.pubkey = UmbralPublicKey(self.bn_key * params.g, params=params)
|
self.pubkey = UmbralPublicKey(self.bn_key * params.g, params=params) # type: ignore
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def gen_key(cls, params: Optional[UmbralParameters] = None) -> 'UmbralPrivateKey':
|
def gen_key(cls, params: Optional[UmbralParameters] = None) -> 'UmbralPrivateKey':
|
||||||
|
@ -269,17 +269,17 @@ class UmbralPublicKey(object):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{}:{}".format(self.__class__.__name__, self.point_key.to_bytes().hex()[:15])
|
return "{}:{}".format(self.__class__.__name__, self.point_key.to_bytes().hex()[:15])
|
||||||
|
|
||||||
def __eq__(self, other: Optional[Union[bytes, 'UmbralPublicKey', int]]) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
if type(other) == bytes:
|
if type(other) == bytes:
|
||||||
is_eq = bytes(other) == bytes(self)
|
is_eq = bytes(other) == bytes(self)
|
||||||
elif hasattr(other, "point_key"):
|
elif hasattr(other, "point_key") and hasattr(other, "params"):
|
||||||
is_eq = self.point_key == other.point_key
|
is_eq = (self.point_key, self.params) == (other.point_key, other.params)
|
||||||
else:
|
else:
|
||||||
is_eq = False
|
is_eq = False
|
||||||
return is_eq
|
return is_eq
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return int.from_bytes(self, byteorder="big")
|
return int.from_bytes(self.to_bytes(), byteorder="big")
|
||||||
|
|
||||||
|
|
||||||
class UmbralKeyingMaterial(object):
|
class UmbralKeyingMaterial(object):
|
||||||
|
|
|
@ -33,7 +33,7 @@ class UmbralParameters(object):
|
||||||
parameters_seed = b'NuCypher/UmbralParameters/'
|
parameters_seed = b'NuCypher/UmbralParameters/'
|
||||||
self.u = unsafe_hash_to_point(g_bytes, self, parameters_seed + b'u')
|
self.u = unsafe_hash_to_point(g_bytes, self, parameters_seed + b'u')
|
||||||
|
|
||||||
def __eq__(self, other: 'UmbralParameters') -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
|
|
||||||
# TODO: This is not comparing the order, which currently is an OpenSSL pointer
|
# TODO: This is not comparing the order, which currently is an OpenSSL pointer
|
||||||
self_attributes = self.curve, self.g, self.CURVE_KEY_SIZE_BYTES, self.u
|
self_attributes = self.curve, self.g, self.CURVE_KEY_SIZE_BYTES, self.u
|
||||||
|
|
|
@ -161,7 +161,7 @@ class Point(object):
|
||||||
# 1 is not-equal, 0 is equal, -1 is error
|
# 1 is not-equal, 0 is equal, -1 is error
|
||||||
return not bool(is_equal)
|
return not bool(is_equal)
|
||||||
|
|
||||||
def __mul__(self, other) -> 'Point':
|
def __mul__(self, other: CurveBN) -> 'Point':
|
||||||
"""
|
"""
|
||||||
Performs an EC_POINT_mul on an EC_POINT and a BIGNUM.
|
Performs an EC_POINT_mul on an EC_POINT and a BIGNUM.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -19,7 +19,7 @@ along with pyUmbral. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import typing
|
import typing
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union, Any
|
||||||
|
|
||||||
from bytestring_splitter import BytestringSplitter
|
from bytestring_splitter import BytestringSplitter
|
||||||
from umbral._pre import prove_cfrag_correctness
|
from umbral._pre import prove_cfrag_correctness
|
||||||
|
@ -109,7 +109,7 @@ class Capsule(object):
|
||||||
components = splitter(capsule_bytes)
|
components = splitter(capsule_bytes)
|
||||||
return cls(params, *components)
|
return cls(params, *components)
|
||||||
|
|
||||||
def _set_cfrag_correctness_key(self, key_type: str, key: UmbralPublicKey) -> bool:
|
def _set_cfrag_correctness_key(self, key_type: str, key: Optional[UmbralPublicKey]) -> bool:
|
||||||
if key_type not in ("delegating", "receiving", "verifying"):
|
if key_type not in ("delegating", "receiving", "verifying"):
|
||||||
raise ValueError("You can only set 'delegating', 'receiving' or 'verifying' keys.")
|
raise ValueError("You can only set 'delegating', 'receiving' or 'verifying' keys.")
|
||||||
|
|
||||||
|
@ -147,7 +147,8 @@ class Capsule(object):
|
||||||
"""
|
"""
|
||||||
Serialize the Capsule into a bytestring.
|
Serialize the Capsule into a bytestring.
|
||||||
"""
|
"""
|
||||||
return bytes().join(c.to_bytes() for c in self.components())
|
e, v, s = self.components()
|
||||||
|
return e.to_bytes() + v.to_bytes() + s.to_bytes()
|
||||||
|
|
||||||
def verify(self) -> bool:
|
def verify(self) -> bool:
|
||||||
|
|
||||||
|
@ -172,11 +173,11 @@ class Capsule(object):
|
||||||
def __bytes__(self) -> bytes:
|
def __bytes__(self) -> bytes:
|
||||||
return self.to_bytes()
|
return self.to_bytes()
|
||||||
|
|
||||||
def __eq__(self, other: 'Capsule') -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
"""
|
"""
|
||||||
Each component is compared to its counterpart in constant time per the __eq__ of Point and CurveBN.
|
Each component is compared to its counterpart in constant time per the __eq__ of Point and CurveBN.
|
||||||
"""
|
"""
|
||||||
return self.components() == other.components() and all(self.components())
|
return hasattr(other, "components") and self.components() == other.components() and all(self.components())
|
||||||
|
|
||||||
@typing.no_type_check
|
@typing.no_type_check
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
|
@ -218,14 +219,13 @@ def generate_kfrags(delegating_privkey: UmbralPrivateKey,
|
||||||
g = params.g
|
g = params.g
|
||||||
|
|
||||||
delegating_pubkey = delegating_privkey.get_pubkey()
|
delegating_pubkey = delegating_privkey.get_pubkey()
|
||||||
delegating_privkey = delegating_privkey.bn_key
|
|
||||||
|
|
||||||
bob_pubkey_point = receiving_pubkey.point_key
|
bob_pubkey_point = receiving_pubkey.point_key
|
||||||
|
|
||||||
# The precursor point is used as an ephemeral public key in a DH key exchange,
|
# The precursor point is used as an ephemeral public key in a DH key exchange,
|
||||||
# and the resulting shared secret 'dh_point' is used to derive other secret values
|
# and the resulting shared secret 'dh_point' is used to derive other secret values
|
||||||
private_precursor = CurveBN.gen_rand(params.curve)
|
private_precursor = CurveBN.gen_rand(params.curve)
|
||||||
precursor = private_precursor * g
|
precursor = private_precursor * g # type: Any
|
||||||
|
|
||||||
dh_point = private_precursor * bob_pubkey_point
|
dh_point = private_precursor * bob_pubkey_point
|
||||||
|
|
||||||
|
@ -239,7 +239,7 @@ def generate_kfrags(delegating_privkey: UmbralPrivateKey,
|
||||||
params=params)
|
params=params)
|
||||||
|
|
||||||
# Coefficients of the generating polynomial
|
# Coefficients of the generating polynomial
|
||||||
coefficients = [delegating_privkey * (~d)]
|
coefficients = [delegating_privkey.bn_key * (~d)]
|
||||||
coefficients += [CurveBN.gen_rand(params.curve) for _ in range(threshold - 1)]
|
coefficients += [CurveBN.gen_rand(params.curve) for _ in range(threshold - 1)]
|
||||||
|
|
||||||
bn_size = CurveBN.expected_bytes_length(params.curve)
|
bn_size = CurveBN.expected_bytes_length(params.curve)
|
||||||
|
@ -263,14 +263,14 @@ def generate_kfrags(delegating_privkey: UmbralPrivateKey,
|
||||||
# polynomial for the index value
|
# polynomial for the index value
|
||||||
rk = poly_eval(coefficients, share_index)
|
rk = poly_eval(coefficients, share_index)
|
||||||
|
|
||||||
commitment = rk * params.u
|
commitment = rk * params.u # type: Any
|
||||||
|
|
||||||
validity_message_for_bob = (kfrag_id,
|
validity_message_for_bob = (kfrag_id,
|
||||||
delegating_pubkey,
|
delegating_pubkey,
|
||||||
receiving_pubkey,
|
receiving_pubkey,
|
||||||
commitment,
|
commitment,
|
||||||
precursor,
|
precursor,
|
||||||
)
|
) # type: Any
|
||||||
validity_message_for_bob = bytes().join(bytes(item) for item in validity_message_for_bob)
|
validity_message_for_bob = bytes().join(bytes(item) for item in validity_message_for_bob)
|
||||||
signature_for_bob = signer(validity_message_for_bob)
|
signature_for_bob = signer(validity_message_for_bob)
|
||||||
|
|
||||||
|
@ -283,7 +283,7 @@ def generate_kfrags(delegating_privkey: UmbralPrivateKey,
|
||||||
else:
|
else:
|
||||||
mode = NO_KEY
|
mode = NO_KEY
|
||||||
|
|
||||||
validity_message_for_proxy = [kfrag_id, commitment, precursor, mode]
|
validity_message_for_proxy = [kfrag_id, commitment, precursor, mode] # type: Any
|
||||||
|
|
||||||
if sign_delegating_key:
|
if sign_delegating_key:
|
||||||
validity_message_for_proxy.append(delegating_pubkey)
|
validity_message_for_proxy.append(delegating_pubkey)
|
||||||
|
@ -316,8 +316,8 @@ def reencrypt(kfrag: KFrag, capsule: Capsule, provide_proof: bool = True,
|
||||||
raise KFrag.NotValid
|
raise KFrag.NotValid
|
||||||
|
|
||||||
rk = kfrag._bn_key
|
rk = kfrag._bn_key
|
||||||
e1 = rk * capsule._point_e
|
e1 = rk * capsule._point_e # type: Any
|
||||||
v1 = rk * capsule._point_v
|
v1 = rk * capsule._point_v # type: Any
|
||||||
|
|
||||||
cfrag = CapsuleFrag(point_e1=e1, point_v1=v1, kfrag_id=kfrag.id,
|
cfrag = CapsuleFrag(point_e1=e1, point_v1=v1, kfrag_id=kfrag.id,
|
||||||
point_precursor=kfrag._point_precursor)
|
point_precursor=kfrag._point_precursor)
|
||||||
|
@ -336,15 +336,15 @@ def _encapsulate(alice_pubkey: UmbralPublicKey,
|
||||||
g = params.g
|
g = params.g
|
||||||
|
|
||||||
priv_r = CurveBN.gen_rand(params.curve)
|
priv_r = CurveBN.gen_rand(params.curve)
|
||||||
pub_r = priv_r * g
|
pub_r = priv_r * g # type: Any
|
||||||
|
|
||||||
priv_u = CurveBN.gen_rand(params.curve)
|
priv_u = CurveBN.gen_rand(params.curve)
|
||||||
pub_u = priv_u * g
|
pub_u = priv_u * g # type: Any
|
||||||
|
|
||||||
h = CurveBN.hash(pub_r, pub_u, params=params)
|
h = CurveBN.hash(pub_r, pub_u, params=params)
|
||||||
s = priv_u + (priv_r * h)
|
s = priv_u + (priv_r * h)
|
||||||
|
|
||||||
shared_key = (priv_r + priv_u) * alice_pubkey.point_key
|
shared_key = (priv_r + priv_u) * alice_pubkey.point_key # type: Any
|
||||||
|
|
||||||
# Key to be used for symmetric encryption
|
# Key to be used for symmetric encryption
|
||||||
key = kdf(shared_key, key_length)
|
key = kdf(shared_key, key_length)
|
||||||
|
@ -361,7 +361,7 @@ def _decapsulate_original(priv_key: UmbralPrivateKey,
|
||||||
# Check correctness of original ciphertext
|
# Check correctness of original ciphertext
|
||||||
raise capsule.NotValid("Capsule verification failed.")
|
raise capsule.NotValid("Capsule verification failed.")
|
||||||
|
|
||||||
shared_key = priv_key.bn_key * (capsule._point_e + capsule._point_v)
|
shared_key = priv_key.bn_key * (capsule._point_e + capsule._point_v) # type: Any
|
||||||
key = kdf(shared_key, key_length)
|
key = kdf(shared_key, key_length)
|
||||||
return key
|
return key
|
||||||
|
|
||||||
|
@ -415,7 +415,7 @@ def _decapsulate_reencrypted(receiving_privkey: UmbralPrivateKey, capsule: Capsu
|
||||||
e, v, s = capsule.components()
|
e, v, s = capsule.components()
|
||||||
h = CurveBN.hash(e, v, params=params)
|
h = CurveBN.hash(e, v, params=params)
|
||||||
|
|
||||||
orig_pub_key = capsule.get_correctness_keys()['delegating'].point_key
|
orig_pub_key = capsule.get_correctness_keys()['delegating'].point_key # type: ignore
|
||||||
|
|
||||||
if not (s / d) * orig_pub_key == (h * e_prime) + v_prime:
|
if not (s / d) * orig_pub_key == (h * e_prime) + v_prime:
|
||||||
raise GenericUmbralError()
|
raise GenericUmbralError()
|
||||||
|
|
|
@ -105,7 +105,7 @@ class Signature:
|
||||||
def __radd__(self, other: bytes) -> bytes:
|
def __radd__(self, other: bytes) -> bytes:
|
||||||
return other + bytes(self)
|
return other + bytes(self)
|
||||||
|
|
||||||
def __eq__(self, other: 'Signature') -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
simple_bytes_match = hmac.compare_digest(bytes(self), bytes(other))
|
simple_bytes_match = hmac.compare_digest(bytes(self), bytes(other))
|
||||||
der_encoded_match = hmac.compare_digest(self._der_encoded_bytes(), bytes(other))
|
der_encoded_match = hmac.compare_digest(self._der_encoded_bytes(), bytes(other))
|
||||||
return simple_bytes_match or der_encoded_match
|
return simple_bytes_match or der_encoded_match
|
||||||
|
|
|
@ -24,13 +24,14 @@ from cryptography.hazmat.primitives import hashes
|
||||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||||
|
|
||||||
from umbral.curvebn import CurveBN
|
from umbral.curvebn import CurveBN
|
||||||
|
from umbral.point import Point
|
||||||
|
|
||||||
|
|
||||||
def lambda_coeff(id_i: CurveBN, selected_ids: List[CurveBN]) -> CurveBN:
|
def lambda_coeff(id_i: CurveBN, selected_ids: List[CurveBN]) -> CurveBN:
|
||||||
ids = [x for x in selected_ids if x != id_i]
|
ids = [x for x in selected_ids if x != id_i]
|
||||||
|
|
||||||
if not ids:
|
if not ids:
|
||||||
return None
|
CurveBN.from_int(1, id_i.curve)
|
||||||
|
|
||||||
result = ids[0] / (ids[0] - id_i)
|
result = ids[0] / (ids[0] - id_i)
|
||||||
for id_j in ids[1:]:
|
for id_j in ids[1:]:
|
||||||
|
@ -47,7 +48,7 @@ def poly_eval(coeff: List[CurveBN], x: CurveBN) -> CurveBN:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def kdf(ecpoint: 'Point', key_length: int) -> bytes:
|
def kdf(ecpoint: Point, key_length: int) -> bytes:
|
||||||
data = ecpoint.to_bytes(is_compressed=True)
|
data = ecpoint.to_bytes(is_compressed=True)
|
||||||
|
|
||||||
return HKDF(
|
return HKDF(
|
||||||
|
|
Loading…
Reference in New Issue