mirror of https://github.com/nucypher/pyUmbral.git
curve_scalar: don't check range in __init__, only in publicly used constructors
parent
d65969761c
commit
f58a2580dc
|
@ -17,9 +17,6 @@ class CurveScalar(Serializable):
|
|||
"""
|
||||
|
||||
def __init__(self, backend_bignum):
|
||||
if not openssl.bn_is_normalized(backend_bignum, CURVE.bn_order):
|
||||
raise ValueError("The provided BIGNUM is not on the provided curve.")
|
||||
|
||||
self._backend_bignum = backend_bignum
|
||||
|
||||
@classmethod
|
||||
|
@ -30,11 +27,12 @@ class CurveScalar(Serializable):
|
|||
return cls(openssl.bn_random_nonzero(CURVE.bn_order))
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, num: int) -> 'CurveScalar':
|
||||
def from_int(cls, num: int, check_normalization: bool = True) -> 'CurveScalar':
|
||||
"""
|
||||
Returns a CurveScalar object from a given integer on a curve.
|
||||
"""
|
||||
conv_bn = openssl.bn_from_int(num, modulus=CURVE.bn_order)
|
||||
modulus = CURVE.bn_order if check_normalization else None
|
||||
conv_bn = openssl.bn_from_int(num, check_modulus=modulus)
|
||||
return cls(conv_bn)
|
||||
|
||||
@classmethod
|
||||
|
@ -43,13 +41,13 @@ class CurveScalar(Serializable):
|
|||
# Currently just matching what we have in RustCrypto stack
|
||||
# (taking bytes modulo curve order).
|
||||
# Can produce zeros!
|
||||
bn = openssl.bn_from_bytes(digest.finalize(), modulus=CURVE.bn_order)
|
||||
bn = openssl.bn_from_bytes(digest.finalize(), apply_modulus=CURVE.bn_order)
|
||||
return cls(bn)
|
||||
|
||||
@classmethod
|
||||
def __take__(cls, data: bytes) -> Tuple['CurveScalar', bytes]:
|
||||
scalar_data, data = cls.__take_bytes__(data, CURVE.scalar_size)
|
||||
bignum = openssl.bn_from_bytes(scalar_data)
|
||||
bignum = openssl.bn_from_bytes(scalar_data, check_modulus=CURVE.bn_order)
|
||||
return cls(bignum), data
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
|
|
|
@ -57,14 +57,18 @@ class SecretKey(Serializable):
|
|||
|
||||
backend_sk = openssl.bn_to_privkey(CURVE, self._scalar_key._backend_bignum)
|
||||
signature_der_bytes = backend_sk.sign(message, signature_algorithm)
|
||||
r, s = utils.decode_dss_signature(signature_der_bytes)
|
||||
r_int, s_int = utils.decode_dss_signature(signature_der_bytes)
|
||||
|
||||
# Normalize s
|
||||
# s is public, so no constant-timeness required here
|
||||
if s > (CURVE.order >> 1):
|
||||
s = CURVE.order - s
|
||||
if s_int > (CURVE.order >> 1):
|
||||
s_int = CURVE.order - s_int
|
||||
|
||||
return Signature(CurveScalar.from_int(r), CurveScalar.from_int(s))
|
||||
# Already normalized, don't waste time
|
||||
r = CurveScalar.from_int(r_int, check_normalization=False)
|
||||
s = CurveScalar.from_int(s_int, check_normalization=False)
|
||||
|
||||
return Signature(r, s)
|
||||
|
||||
|
||||
class Signature(Serializable):
|
||||
|
|
|
@ -144,7 +144,7 @@ def bn_is_normalized(check_bn, modulus):
|
|||
return (check_sign == 1 or check_sign == 0) and range_check == -1
|
||||
|
||||
|
||||
def bn_from_int(py_int: int, modulus=None, set_consttime_flag=True):
|
||||
def bn_from_int(py_int: int, check_modulus=None, set_consttime_flag=True):
|
||||
"""
|
||||
Converts the given Python int to an OpenSSL BIGNUM. If ``modulus`` is
|
||||
provided, it will check if the Python integer is within ``[0, modulus)``.
|
||||
|
@ -155,15 +155,15 @@ def bn_from_int(py_int: int, modulus=None, set_consttime_flag=True):
|
|||
conv_bn = backend._int_to_bn(py_int)
|
||||
conv_bn = backend._ffi.gc(conv_bn, backend._lib.BN_clear_free)
|
||||
|
||||
if modulus and not bn_is_normalized(conv_bn, modulus):
|
||||
raise ValueError("The Python integer given is not under the provided modulus.")
|
||||
if check_modulus and not bn_is_normalized(conv_bn, check_modulus):
|
||||
raise ValueError(f"The Python integer given ({py_int}) is not under the provided modulus.")
|
||||
|
||||
if set_consttime_flag:
|
||||
backend._lib.BN_set_flags(conv_bn, backend._lib.BN_FLG_CONSTTIME)
|
||||
return conv_bn
|
||||
|
||||
|
||||
def bn_from_bytes(bytes_seq: bytes, set_consttime_flag=True, modulus=None):
|
||||
def bn_from_bytes(bytes_seq: bytes, set_consttime_flag=True, check_modulus=None, apply_modulus=None):
|
||||
"""
|
||||
Converts the given byte sequence to an OpenSSL BIGNUM.
|
||||
If set_consttime_flag is set to True, OpenSSL will use constant time
|
||||
|
@ -173,10 +173,14 @@ def bn_from_bytes(bytes_seq: bytes, set_consttime_flag=True, modulus=None):
|
|||
backend._lib.BN_bin2bn(bytes_seq, len(bytes_seq), bn)
|
||||
backend.openssl_assert(bn != backend._ffi.NULL)
|
||||
|
||||
if modulus:
|
||||
if check_modulus and not bn_is_normalized(bn, check_modulus):
|
||||
raise ValueError(f"The integer encoded with given bytes ({repr(bytes_seq)}) "
|
||||
"is not under the provided modulus.")
|
||||
|
||||
if apply_modulus:
|
||||
bignum =_bn_new()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.BN_mod(bignum, bn, modulus, bn_ctx)
|
||||
res = backend._lib.BN_mod(bignum, bn, apply_modulus, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
return bn
|
||||
|
|
Loading…
Reference in New Issue