Some type annotation improvements and other minor change requests

pull/220/head
David Núñez 2018-10-08 18:24:17 +02:00
parent 7d9ddfe6e8
commit 87b24a0083
13 changed files with 84 additions and 72 deletions

View File

@ -203,7 +203,7 @@ def test_umbral_public_key_as_dict_key():
another_umbral_pub_key = another_umbral_priv_key.get_pubkey() another_umbral_pub_key = another_umbral_priv_key.get_pubkey()
with pytest.raises(KeyError): with pytest.raises(KeyError):
d[another_umbral_pub_key] _ = d[another_umbral_pub_key]
d[another_umbral_pub_key] = False d[another_umbral_pub_key] = False

View File

@ -55,9 +55,9 @@ def prove_cfrag_correctness(cfrag: 'CapsuleFrag',
v2 = t * v v2 = t * v
u2 = t * u u2 = t * u
hash_input = (e, e1, e2, v, v1, v2, u, u1, u2) hash_input = [e, e1, e2, v, v1, v2, u, u1, u2]
if metadata is not None: if metadata is not None:
hash_input += (metadata,) hash_input.append(metadata)
h = CurveBN.hash(*hash_input, params=params) h = CurveBN.hash(*hash_input, params=params)
######## ########
@ -98,9 +98,9 @@ def assess_cfrag_correctness(cfrag: 'CapsuleFrag', capsule: 'Capsule') -> bool:
else: else:
raise raise
hash_input = (e, e1, e2, v, v1, v2, u, u1, u2) hash_input = [e, e1, e2, v, v1, v2, u, u1, u2]
if cfrag.proof.metadata is not None: if cfrag.proof.metadata is not None:
hash_input += (cfrag.proof.metadata,) hash_input.append(cfrag.proof.metadata)
h = CurveBN.hash(*hash_input, params=params) h = CurveBN.hash(*hash_input, params=params)
######## ########

View File

@ -28,7 +28,11 @@ class _CONFIG:
__curve = None __curve = None
__params = None __params = None
__CURVE_TO_USE_IF_NO_DEFAULT_IS_SET_BY_USER = SECP256K1 __CURVE_TO_USE_IF_NO_DEFAULT_IS_SET_BY_USER = SECP256K1
__WARNING_IF_NO_DEFAULT_SET = "No default curve has been set. Using SECP256K1. A slight performance penalty has been incurred for only this call. Set a default curve with umbral.config.set_default_curve()." __WARNING_IF_NO_DEFAULT_SET = "No default curve has been set. " \
"Using SECP256K1. " \
"A slight performance penalty has been " \
"incurred for only this call. Set a default " \
"curve with umbral.config.set_default_curve()."
class UmbralConfigurationError(RuntimeError): class UmbralConfigurationError(RuntimeError):
"""Raised when somebody does something dumb re: configuration.""" """Raised when somebody does something dumb re: configuration."""
@ -42,13 +46,13 @@ class _CONFIG:
def params(cls) -> UmbralParameters: def params(cls) -> UmbralParameters:
if not cls.__params: if not cls.__params:
cls.__set_curve_by_default() cls.__set_curve_by_default()
return cls.__params return cls.__params # type: ignore
@classmethod @classmethod
def curve(cls) -> Curve: def curve(cls) -> Curve:
if not cls.__curve: if not cls.__curve:
cls.__set_curve_by_default() cls.__set_curve_by_default()
return cls.__curve return cls.__curve # type: ignore
@classmethod @classmethod
def set_curve(cls, curve: Optional[Curve] = None) -> None: def set_curve(cls, curve: Optional[Curve] = None) -> None:

View File

@ -56,8 +56,8 @@ class Curve:
self.__generator = openssl._get_ec_generator_by_group(self.ec_group) self.__generator = openssl._get_ec_generator_by_group(self.ec_group)
# Init cache # Init cache
self.__field_order_size_in_bytes = None self.__field_order_size_in_bytes = 0
self.__group_order_size_in_bytes = None self.__group_order_size_in_bytes = 0
@classmethod @classmethod
def from_name(cls, name: str) -> 'Curve': def from_name(cls, name: str) -> 'Curve':
@ -93,15 +93,14 @@ class Curve:
@property @property
def field_order_size_in_bytes(self) -> int: def field_order_size_in_bytes(self) -> int:
if self.__field_order_size_in_bytes is None: if not self.__field_order_size_in_bytes:
backend = default_backend()
size_in_bits = openssl._get_ec_group_degree(self.__ec_group) size_in_bits = openssl._get_ec_group_degree(self.__ec_group)
self.__field_order_size_in_bytes = (size_in_bits + 7) // 8 self.__field_order_size_in_bytes = (size_in_bits + 7) // 8
return self.__field_order_size_in_bytes return self.__field_order_size_in_bytes
@property @property
def group_order_size_in_bytes(self) -> int: def group_order_size_in_bytes(self) -> int:
if self.__group_order_size_in_bytes is None: if not self.__group_order_size_in_bytes:
BN_num_bytes = default_backend()._lib.BN_num_bytes BN_num_bytes = default_backend()._lib.BN_num_bytes
self.__group_order_size_in_bytes = BN_num_bytes(self.order) self.__group_order_size_in_bytes = BN_num_bytes(self.order)
return self.__group_order_size_in_bytes return self.__group_order_size_in_bytes

View File

@ -17,7 +17,7 @@ You should have received a copy of the GNU General Public License
along with pyUmbral. If not, see <https://www.gnu.org/licenses/>. along with pyUmbral. If not, see <https://www.gnu.org/licenses/>.
""" """
from typing import Optional, Union from typing import Optional, Union, cast
from cryptography.hazmat.backends.openssl import backend from cryptography.hazmat.backends.openssl import backend
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
@ -36,7 +36,7 @@ class CurveBN(object):
constant time operations. constant time operations.
""" """
def __init__(self, bignum, curve: Curve): def __init__(self, bignum, curve: Curve) -> None:
on_curve = openssl._bn_is_on_curve(bignum, curve) on_curve = openssl._bn_is_on_curve(bignum, curve)
if not on_curve: if not on_curve:
raise ValueError("The provided BIGNUM is not on the provided curve.") raise ValueError("The provided BIGNUM is not on the provided curve.")
@ -146,7 +146,7 @@ class CurveBN(object):
""" """
return backend._bn_to_int(self.bignum) return backend._bn_to_int(self.bignum)
def __eq__(self, other : Union[int, 'CurveBN']) -> bool: def __eq__(self, other) -> bool:
""" """
Compares the two BIGNUMS or int. Compares the two BIGNUMS or int.
""" """
@ -169,6 +169,8 @@ class CurveBN(object):
other = openssl._int_to_bn(other) other = openssl._int_to_bn(other)
other = CurveBN(other, self.curve) other = CurveBN(other, self.curve)
other = cast('CurveBN', other) # This is just for mypy
power = openssl._get_new_BN() power = openssl._get_new_BN()
with backend._tmp_bn_ctx() as bn_ctx, openssl._tmp_bn_mont_ctx(self.curve.order) as bn_mont_ctx: with backend._tmp_bn_ctx() as bn_ctx, openssl._tmp_bn_mont_ctx(self.curve.order) as bn_mont_ctx:
res = backend._lib.BN_mod_exp_mont( res = backend._lib.BN_mod_exp_mont(
@ -215,7 +217,6 @@ class CurveBN(object):
return CurveBN(product, self.curve) return CurveBN(product, self.curve)
def __add__(self, other : Union[int, 'CurveBN']) -> 'CurveBN': def __add__(self, other : Union[int, 'CurveBN']) -> 'CurveBN':
""" """
Performs a BN_mod_add on two BIGNUMs. Performs a BN_mod_add on two BIGNUMs.
@ -224,6 +225,8 @@ class CurveBN(object):
other = openssl._int_to_bn(other) other = openssl._int_to_bn(other)
other = CurveBN(other, self.curve) other = CurveBN(other, self.curve)
other = cast('CurveBN', other) # This is just for mypy
op_sum = openssl._get_new_BN() op_sum = openssl._get_new_BN()
with backend._tmp_bn_ctx() as bn_ctx: with backend._tmp_bn_ctx() as bn_ctx:
res = backend._lib.BN_mod_add( res = backend._lib.BN_mod_add(
@ -241,6 +244,8 @@ class CurveBN(object):
other = openssl._int_to_bn(other) other = openssl._int_to_bn(other)
other = CurveBN(other, self.curve) other = CurveBN(other, self.curve)
other = cast('CurveBN', other) # This is just for mypy
diff = openssl._get_new_BN() diff = openssl._get_new_BN()
with backend._tmp_bn_ctx() as bn_ctx: with backend._tmp_bn_ctx() as bn_ctx:
res = backend._lib.BN_mod_sub( res = backend._lib.BN_mod_sub(
@ -291,6 +296,8 @@ class CurveBN(object):
other = openssl._int_to_bn(other) other = openssl._int_to_bn(other)
other = CurveBN(other, self.curve) other = CurveBN(other, self.curve)
other = cast('CurveBN', other) # This is just for mypy
rem = openssl._get_new_BN() rem = openssl._get_new_BN()
with backend._tmp_bn_ctx() as bn_ctx: with backend._tmp_bn_ctx() as bn_ctx:
res = backend._lib.BN_nnmod( res = backend._lib.BN_nnmod(

View File

@ -21,7 +21,6 @@ import hmac
from typing import Optional from typing import Optional
from bytestring_splitter import BytestringSplitter from bytestring_splitter import BytestringSplitter
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
from umbral._pre import assess_cfrag_correctness, verify_kfrag from umbral._pre import assess_cfrag_correctness, verify_kfrag
from umbral.config import default_curve, default_params from umbral.config import default_curve, default_params
@ -30,6 +29,7 @@ from umbral.keys import UmbralPublicKey
from umbral.point import Point from umbral.point import Point
from umbral.signing import Signature from umbral.signing import Signature
from umbral.params import UmbralParameters from umbral.params import UmbralParameters
from umbral.curve import Curve
from constant_sorrow.constants import NO_KEY, DELEGATING_ONLY, RECEIVING_ONLY, DELEGATING_AND_RECEIVING from constant_sorrow.constants import NO_KEY, DELEGATING_ONLY, RECEIVING_ONLY, DELEGATING_AND_RECEIVING
@ -64,7 +64,7 @@ class KFrag(object):
""" """
@classmethod @classmethod
def expected_bytes_length(cls, curve: Optional[EllipticCurve] = None) -> int: def expected_bytes_length(cls, curve: Optional[Curve] = None) -> int:
""" """
Returns the size (in bytes) of a KFrag given the curve. Returns the size (in bytes) of a KFrag given the curve.
If no curve is provided, it will use the default curve. If no curve is provided, it will use the default curve.
@ -84,7 +84,7 @@ class KFrag(object):
return bn_size * 6 + point_size * 2 + 1 return bn_size * 6 + point_size * 2 + 1
@classmethod @classmethod
def from_bytes(cls, data: bytes, curve: Optional[EllipticCurve] = None) -> 'KFrag': def from_bytes(cls, data: bytes, curve: Optional[Curve] = None) -> 'KFrag':
""" """
Instantiate a KFrag object from the serialized data. Instantiate a KFrag object from the serialized data.
""" """
@ -184,7 +184,7 @@ class CorrectnessProof(object):
self.kfrag_signature = kfrag_signature self.kfrag_signature = kfrag_signature
@classmethod @classmethod
def expected_bytes_length(cls, curve: Optional[EllipticCurve] = None): def expected_bytes_length(cls, curve: Optional[Curve] = None):
""" """
Returns the size (in bytes) of a CorrectnessProof without the metadata. Returns the size (in bytes) of a CorrectnessProof without the metadata.
If no curve is given, it will use the default curve. If no curve is given, it will use the default curve.
@ -196,7 +196,7 @@ class CorrectnessProof(object):
return (bn_size * 3) + (point_size * 4) return (bn_size * 3) + (point_size * 4)
@classmethod @classmethod
def from_bytes(cls, data: bytes, curve: Optional[EllipticCurve] = None) -> 'CorrectnessProof': def from_bytes(cls, data: bytes, curve: Optional[Curve] = None) -> 'CorrectnessProof':
""" """
Instantiate CorrectnessProof from serialized data. Instantiate CorrectnessProof from serialized data.
""" """
@ -213,9 +213,9 @@ class CorrectnessProof(object):
(Signature, Signature.expected_bytes_length(curve), arguments), # kfrag_signature (Signature, Signature.expected_bytes_length(curve), arguments), # kfrag_signature
) )
components = splitter(data, return_remainder=True) components = splitter(data, return_remainder=True)
metadata = components.pop(-1) or None components.append(components.pop() or None)
return cls(*components, metadata=metadata) return cls(*components)
def to_bytes(self) -> bytes: def to_bytes(self) -> bytes:
""" """
@ -260,7 +260,7 @@ class CapsuleFrag(object):
""" """
@classmethod @classmethod
def expected_bytes_length(cls, curve: Optional[EllipticCurve] = None) -> int: def expected_bytes_length(cls, curve: Optional[Curve] = None) -> int:
""" """
Returns the size (in bytes) of a CapsuleFrag given the curve without Returns the size (in bytes) of a CapsuleFrag given the curve without
the CorrectnessProof. the CorrectnessProof.
@ -273,7 +273,7 @@ class CapsuleFrag(object):
return (bn_size * 1) + (point_size * 3) return (bn_size * 1) + (point_size * 3)
@classmethod @classmethod
def from_bytes(cls, data: bytes, curve: Optional[EllipticCurve] = None) -> 'CapsuleFrag': def from_bytes(cls, data: bytes, curve: Optional[Curve] = None) -> 'CapsuleFrag':
""" """
Instantiates a CapsuleFrag object from the serialized data. Instantiates a CapsuleFrag object from the serialized data.
""" """
@ -291,9 +291,10 @@ class CapsuleFrag(object):
) )
components = splitter(data, return_remainder=True) components = splitter(data, return_remainder=True)
proof = components.pop(-1) or None proof = components.pop() or None
proof = CorrectnessProof.from_bytes(proof, curve) if proof else None components.append(CorrectnessProof.from_bytes(proof, curve) if proof else None)
return cls(*components, proof=proof)
return cls(*components)
def to_bytes(self) -> bytes: def to_bytes(self) -> bytes:
""" """

View File

@ -18,7 +18,7 @@ along with pyUmbral. If not, see <https://www.gnu.org/licenses/>.
""" """
import os import os
from typing import Callable, Optional, Union from typing import Callable, Optional, Union, Any
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.backends.openssl.ec import _EllipticCurvePrivateKey, _EllipticCurvePublicKey from cryptography.hazmat.backends.openssl.ec import _EllipticCurvePrivateKey, _EllipticCurvePublicKey
@ -42,7 +42,7 @@ class UmbralPrivateKey(object):
""" """
self.params = params self.params = params
self.bn_key = bn_key self.bn_key = bn_key
self.pubkey = UmbralPublicKey(self.bn_key * params.g, params=params) self.pubkey = UmbralPublicKey(self.bn_key * params.g, params=params) # type: ignore
@classmethod @classmethod
def gen_key(cls, params: Optional[UmbralParameters] = None) -> 'UmbralPrivateKey': def gen_key(cls, params: Optional[UmbralParameters] = None) -> 'UmbralPrivateKey':
@ -269,17 +269,17 @@ class UmbralPublicKey(object):
def __repr__(self): def __repr__(self):
return "{}:{}".format(self.__class__.__name__, self.point_key.to_bytes().hex()[:15]) return "{}:{}".format(self.__class__.__name__, self.point_key.to_bytes().hex()[:15])
def __eq__(self, other: Optional[Union[bytes, 'UmbralPublicKey', int]]) -> bool: def __eq__(self, other: Any) -> bool:
if type(other) == bytes: if type(other) == bytes:
is_eq = bytes(other) == bytes(self) is_eq = bytes(other) == bytes(self)
elif hasattr(other, "point_key"): elif hasattr(other, "point_key") and hasattr(other, "params"):
is_eq = self.point_key == other.point_key is_eq = (self.point_key, self.params) == (other.point_key, other.params)
else: else:
is_eq = False is_eq = False
return is_eq return is_eq
def __hash__(self) -> int: def __hash__(self) -> int:
return int.from_bytes(self, byteorder="big") return int.from_bytes(self.to_bytes(), byteorder="big")
class UmbralKeyingMaterial(object): class UmbralKeyingMaterial(object):

View File

@ -33,7 +33,7 @@ class UmbralParameters(object):
parameters_seed = b'NuCypher/UmbralParameters/' parameters_seed = b'NuCypher/UmbralParameters/'
self.u = unsafe_hash_to_point(g_bytes, self, parameters_seed + b'u') self.u = unsafe_hash_to_point(g_bytes, self, parameters_seed + b'u')
def __eq__(self, other: 'UmbralParameters') -> bool: def __eq__(self, other) -> bool:
# TODO: This is not comparing the order, which currently is an OpenSSL pointer # TODO: This is not comparing the order, which currently is an OpenSSL pointer
self_attributes = self.curve, self.g, self.CURVE_KEY_SIZE_BYTES, self.u self_attributes = self.curve, self.g, self.CURVE_KEY_SIZE_BYTES, self.u

View File

@ -161,7 +161,7 @@ class Point(object):
# 1 is not-equal, 0 is equal, -1 is error # 1 is not-equal, 0 is equal, -1 is error
return not bool(is_equal) return not bool(is_equal)
def __mul__(self, other) -> 'Point': def __mul__(self, other: CurveBN) -> 'Point':
""" """
Performs an EC_POINT_mul on an EC_POINT and a BIGNUM. Performs an EC_POINT_mul on an EC_POINT and a BIGNUM.
""" """

View File

@ -19,7 +19,7 @@ along with pyUmbral. If not, see <https://www.gnu.org/licenses/>.
import os import os
import typing import typing
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union, Any
from bytestring_splitter import BytestringSplitter from bytestring_splitter import BytestringSplitter
from umbral._pre import prove_cfrag_correctness from umbral._pre import prove_cfrag_correctness
@ -109,7 +109,7 @@ class Capsule(object):
components = splitter(capsule_bytes) components = splitter(capsule_bytes)
return cls(params, *components) return cls(params, *components)
def _set_cfrag_correctness_key(self, key_type: str, key: UmbralPublicKey) -> bool: def _set_cfrag_correctness_key(self, key_type: str, key: Optional[UmbralPublicKey]) -> bool:
if key_type not in ("delegating", "receiving", "verifying"): if key_type not in ("delegating", "receiving", "verifying"):
raise ValueError("You can only set 'delegating', 'receiving' or 'verifying' keys.") raise ValueError("You can only set 'delegating', 'receiving' or 'verifying' keys.")
@ -147,7 +147,8 @@ class Capsule(object):
""" """
Serialize the Capsule into a bytestring. Serialize the Capsule into a bytestring.
""" """
return bytes().join(c.to_bytes() for c in self.components()) e, v, s = self.components()
return e.to_bytes() + v.to_bytes() + s.to_bytes()
def verify(self) -> bool: def verify(self) -> bool:
@ -172,11 +173,11 @@ class Capsule(object):
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return self.to_bytes() return self.to_bytes()
def __eq__(self, other: 'Capsule') -> bool: def __eq__(self, other) -> bool:
""" """
Each component is compared to its counterpart in constant time per the __eq__ of Point and CurveBN. Each component is compared to its counterpart in constant time per the __eq__ of Point and CurveBN.
""" """
return self.components() == other.components() and all(self.components()) return hasattr(other, "components") and self.components() == other.components() and all(self.components())
@typing.no_type_check @typing.no_type_check
def __hash__(self) -> int: def __hash__(self) -> int:
@ -218,14 +219,13 @@ def generate_kfrags(delegating_privkey: UmbralPrivateKey,
g = params.g g = params.g
delegating_pubkey = delegating_privkey.get_pubkey() delegating_pubkey = delegating_privkey.get_pubkey()
delegating_privkey = delegating_privkey.bn_key
bob_pubkey_point = receiving_pubkey.point_key bob_pubkey_point = receiving_pubkey.point_key
# The precursor point is used as an ephemeral public key in a DH key exchange, # The precursor point is used as an ephemeral public key in a DH key exchange,
# and the resulting shared secret 'dh_point' is used to derive other secret values # and the resulting shared secret 'dh_point' is used to derive other secret values
private_precursor = CurveBN.gen_rand(params.curve) private_precursor = CurveBN.gen_rand(params.curve)
precursor = private_precursor * g precursor = private_precursor * g # type: Any
dh_point = private_precursor * bob_pubkey_point dh_point = private_precursor * bob_pubkey_point
@ -239,7 +239,7 @@ def generate_kfrags(delegating_privkey: UmbralPrivateKey,
params=params) params=params)
# Coefficients of the generating polynomial # Coefficients of the generating polynomial
coefficients = [delegating_privkey * (~d)] coefficients = [delegating_privkey.bn_key * (~d)]
coefficients += [CurveBN.gen_rand(params.curve) for _ in range(threshold - 1)] coefficients += [CurveBN.gen_rand(params.curve) for _ in range(threshold - 1)]
bn_size = CurveBN.expected_bytes_length(params.curve) bn_size = CurveBN.expected_bytes_length(params.curve)
@ -263,14 +263,14 @@ def generate_kfrags(delegating_privkey: UmbralPrivateKey,
# polynomial for the index value # polynomial for the index value
rk = poly_eval(coefficients, share_index) rk = poly_eval(coefficients, share_index)
commitment = rk * params.u commitment = rk * params.u # type: Any
validity_message_for_bob = (kfrag_id, validity_message_for_bob = (kfrag_id,
delegating_pubkey, delegating_pubkey,
receiving_pubkey, receiving_pubkey,
commitment, commitment,
precursor, precursor,
) ) # type: Any
validity_message_for_bob = bytes().join(bytes(item) for item in validity_message_for_bob) validity_message_for_bob = bytes().join(bytes(item) for item in validity_message_for_bob)
signature_for_bob = signer(validity_message_for_bob) signature_for_bob = signer(validity_message_for_bob)
@ -283,7 +283,7 @@ def generate_kfrags(delegating_privkey: UmbralPrivateKey,
else: else:
mode = NO_KEY mode = NO_KEY
validity_message_for_proxy = [kfrag_id, commitment, precursor, mode] validity_message_for_proxy = [kfrag_id, commitment, precursor, mode] # type: Any
if sign_delegating_key: if sign_delegating_key:
validity_message_for_proxy.append(delegating_pubkey) validity_message_for_proxy.append(delegating_pubkey)
@ -316,8 +316,8 @@ def reencrypt(kfrag: KFrag, capsule: Capsule, provide_proof: bool = True,
raise KFrag.NotValid raise KFrag.NotValid
rk = kfrag._bn_key rk = kfrag._bn_key
e1 = rk * capsule._point_e e1 = rk * capsule._point_e # type: Any
v1 = rk * capsule._point_v v1 = rk * capsule._point_v # type: Any
cfrag = CapsuleFrag(point_e1=e1, point_v1=v1, kfrag_id=kfrag.id, cfrag = CapsuleFrag(point_e1=e1, point_v1=v1, kfrag_id=kfrag.id,
point_precursor=kfrag._point_precursor) point_precursor=kfrag._point_precursor)
@ -336,15 +336,15 @@ def _encapsulate(alice_pubkey: UmbralPublicKey,
g = params.g g = params.g
priv_r = CurveBN.gen_rand(params.curve) priv_r = CurveBN.gen_rand(params.curve)
pub_r = priv_r * g pub_r = priv_r * g # type: Any
priv_u = CurveBN.gen_rand(params.curve) priv_u = CurveBN.gen_rand(params.curve)
pub_u = priv_u * g pub_u = priv_u * g # type: Any
h = CurveBN.hash(pub_r, pub_u, params=params) h = CurveBN.hash(pub_r, pub_u, params=params)
s = priv_u + (priv_r * h) s = priv_u + (priv_r * h)
shared_key = (priv_r + priv_u) * alice_pubkey.point_key shared_key = (priv_r + priv_u) * alice_pubkey.point_key # type: Any
# Key to be used for symmetric encryption # Key to be used for symmetric encryption
key = kdf(shared_key, key_length) key = kdf(shared_key, key_length)
@ -361,7 +361,7 @@ def _decapsulate_original(priv_key: UmbralPrivateKey,
# Check correctness of original ciphertext # Check correctness of original ciphertext
raise capsule.NotValid("Capsule verification failed.") raise capsule.NotValid("Capsule verification failed.")
shared_key = priv_key.bn_key * (capsule._point_e + capsule._point_v) shared_key = priv_key.bn_key * (capsule._point_e + capsule._point_v) # type: Any
key = kdf(shared_key, key_length) key = kdf(shared_key, key_length)
return key return key
@ -415,7 +415,7 @@ def _decapsulate_reencrypted(receiving_privkey: UmbralPrivateKey, capsule: Capsu
e, v, s = capsule.components() e, v, s = capsule.components()
h = CurveBN.hash(e, v, params=params) h = CurveBN.hash(e, v, params=params)
orig_pub_key = capsule.get_correctness_keys()['delegating'].point_key orig_pub_key = capsule.get_correctness_keys()['delegating'].point_key # type: ignore
if not (s / d) * orig_pub_key == (h * e_prime) + v_prime: if not (s / d) * orig_pub_key == (h * e_prime) + v_prime:
raise GenericUmbralError() raise GenericUmbralError()

View File

@ -105,7 +105,7 @@ class Signature:
def __radd__(self, other: bytes) -> bytes: def __radd__(self, other: bytes) -> bytes:
return other + bytes(self) return other + bytes(self)
def __eq__(self, other: 'Signature') -> bool: def __eq__(self, other) -> bool:
simple_bytes_match = hmac.compare_digest(bytes(self), bytes(other)) simple_bytes_match = hmac.compare_digest(bytes(self), bytes(other))
der_encoded_match = hmac.compare_digest(self._der_encoded_bytes(), bytes(other)) der_encoded_match = hmac.compare_digest(self._der_encoded_bytes(), bytes(other))
return simple_bytes_match or der_encoded_match return simple_bytes_match or der_encoded_match

View File

@ -24,13 +24,14 @@ from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from umbral.curvebn import CurveBN from umbral.curvebn import CurveBN
from umbral.point import Point
def lambda_coeff(id_i: CurveBN, selected_ids: List[CurveBN]) -> CurveBN: def lambda_coeff(id_i: CurveBN, selected_ids: List[CurveBN]) -> CurveBN:
ids = [x for x in selected_ids if x != id_i] ids = [x for x in selected_ids if x != id_i]
if not ids: if not ids:
return None CurveBN.from_int(1, id_i.curve)
result = ids[0] / (ids[0] - id_i) result = ids[0] / (ids[0] - id_i)
for id_j in ids[1:]: for id_j in ids[1:]:
@ -47,7 +48,7 @@ def poly_eval(coeff: List[CurveBN], x: CurveBN) -> CurveBN:
return result return result
def kdf(ecpoint: 'Point', key_length: int) -> bytes: def kdf(ecpoint: Point, key_length: int) -> bytes:
data = ecpoint.to_bytes(is_compressed=True) data = ecpoint.to_bytes(is_compressed=True)
return HKDF( return HKDF(