curve_scalar: don't check range in __init__, only in publicly used constructors

pull/263/head
Bogdan Opanchuk 2021-03-19 20:18:28 -07:00
parent d65969761c
commit f58a2580dc
3 changed files with 23 additions and 17 deletions

View File

@ -17,9 +17,6 @@ class CurveScalar(Serializable):
"""
def __init__(self, backend_bignum):
if not openssl.bn_is_normalized(backend_bignum, CURVE.bn_order):
raise ValueError("The provided BIGNUM is not on the provided curve.")
self._backend_bignum = backend_bignum
@classmethod
@ -30,11 +27,12 @@ class CurveScalar(Serializable):
return cls(openssl.bn_random_nonzero(CURVE.bn_order))
@classmethod
def from_int(cls, num: int) -> 'CurveScalar':
def from_int(cls, num: int, check_normalization: bool = True) -> 'CurveScalar':
"""
Returns a CurveScalar object from a given integer on a curve.
"""
conv_bn = openssl.bn_from_int(num, modulus=CURVE.bn_order)
modulus = CURVE.bn_order if check_normalization else None
conv_bn = openssl.bn_from_int(num, check_modulus=modulus)
return cls(conv_bn)
@classmethod
@ -43,13 +41,13 @@ class CurveScalar(Serializable):
# Currently just matching what we have in RustCrypto stack
# (taking bytes modulo curve order).
# Can produce zeros!
bn = openssl.bn_from_bytes(digest.finalize(), modulus=CURVE.bn_order)
bn = openssl.bn_from_bytes(digest.finalize(), apply_modulus=CURVE.bn_order)
return cls(bn)
@classmethod
def __take__(cls, data: bytes) -> Tuple['CurveScalar', bytes]:
scalar_data, data = cls.__take_bytes__(data, CURVE.scalar_size)
bignum = openssl.bn_from_bytes(scalar_data)
bignum = openssl.bn_from_bytes(scalar_data, check_modulus=CURVE.bn_order)
return cls(bignum), data
def __bytes__(self) -> bytes:

View File

@ -57,14 +57,18 @@ class SecretKey(Serializable):
backend_sk = openssl.bn_to_privkey(CURVE, self._scalar_key._backend_bignum)
signature_der_bytes = backend_sk.sign(message, signature_algorithm)
r, s = utils.decode_dss_signature(signature_der_bytes)
r_int, s_int = utils.decode_dss_signature(signature_der_bytes)
# Normalize s
# s is public, so no constant-timeness required here
if s > (CURVE.order >> 1):
s = CURVE.order - s
if s_int > (CURVE.order >> 1):
s_int = CURVE.order - s_int
return Signature(CurveScalar.from_int(r), CurveScalar.from_int(s))
# Already normalized, don't waste time
r = CurveScalar.from_int(r_int, check_normalization=False)
s = CurveScalar.from_int(s_int, check_normalization=False)
return Signature(r, s)
class Signature(Serializable):

View File

@ -144,7 +144,7 @@ def bn_is_normalized(check_bn, modulus):
return (check_sign == 1 or check_sign == 0) and range_check == -1
def bn_from_int(py_int: int, modulus=None, set_consttime_flag=True):
def bn_from_int(py_int: int, check_modulus=None, set_consttime_flag=True):
"""
Converts the given Python int to an OpenSSL BIGNUM. If ``modulus`` is
provided, it will check if the Python integer is within ``[0, modulus)``.
@ -155,15 +155,15 @@ def bn_from_int(py_int: int, modulus=None, set_consttime_flag=True):
conv_bn = backend._int_to_bn(py_int)
conv_bn = backend._ffi.gc(conv_bn, backend._lib.BN_clear_free)
if modulus and not bn_is_normalized(conv_bn, modulus):
raise ValueError("The Python integer given is not under the provided modulus.")
if check_modulus and not bn_is_normalized(conv_bn, check_modulus):
raise ValueError(f"The Python integer given ({py_int}) is not under the provided modulus.")
if set_consttime_flag:
backend._lib.BN_set_flags(conv_bn, backend._lib.BN_FLG_CONSTTIME)
return conv_bn
def bn_from_bytes(bytes_seq: bytes, set_consttime_flag=True, modulus=None):
def bn_from_bytes(bytes_seq: bytes, set_consttime_flag=True, check_modulus=None, apply_modulus=None):
"""
Converts the given byte sequence to an OpenSSL BIGNUM.
If set_consttime_flag is set to True, OpenSSL will use constant time
@ -173,10 +173,14 @@ def bn_from_bytes(bytes_seq: bytes, set_consttime_flag=True, modulus=None):
backend._lib.BN_bin2bn(bytes_seq, len(bytes_seq), bn)
backend.openssl_assert(bn != backend._ffi.NULL)
if modulus:
if check_modulus and not bn_is_normalized(bn, check_modulus):
raise ValueError(f"The integer encoded with given bytes ({repr(bytes_seq)}) "
"is not under the provided modulus.")
if apply_modulus:
bignum =_bn_new()
with backend._tmp_bn_ctx() as bn_ctx:
res = backend._lib.BN_mod(bignum, bn, modulus, bn_ctx)
res = backend._lib.BN_mod(bignum, bn, apply_modulus, bn_ctx)
backend.openssl_assert(res == 1)
return bn