mirror of https://github.com/nucypher/pyUmbral.git
Move all OpenSSL stuff into one module, move around some low-level details
parent
b96888cafb
commit
d532ef1383
115
umbral/curve.py
115
umbral/curve.py
|
@ -1,119 +1,10 @@
|
|||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from . import openssl
|
||||
|
||||
|
||||
class Curve:
|
||||
"""
|
||||
Acts as a container to store constant variables such as the OpenSSL
|
||||
curve_nid, the EC_GROUP struct, and the order of the curve.
|
||||
|
||||
Contains a whitelist of supported elliptic curves used in pyUmbral.
|
||||
|
||||
"""
|
||||
|
||||
_supported_curves = {
|
||||
415: 'secp256r1',
|
||||
714: 'secp256k1',
|
||||
715: 'secp384r1'
|
||||
}
|
||||
|
||||
def __init__(self, nid: int) -> None:
|
||||
"""
|
||||
Instantiates an OpenSSL curve with the provided curve_nid and derives
|
||||
the proper EC_GROUP struct and order. You can _only_ instantiate curves
|
||||
with supported nids (see `Curve.supported_curves`).
|
||||
"""
|
||||
|
||||
try:
|
||||
self.__curve_name = self._supported_curves[nid]
|
||||
except KeyError:
|
||||
raise NotImplementedError("Curve NID {} is not supported.".format(nid))
|
||||
|
||||
# set only once
|
||||
self.__curve_nid = nid
|
||||
self.__ec_group = openssl._get_ec_group_by_curve_nid(self.__curve_nid)
|
||||
self.__order = openssl._get_ec_order_by_group(self.ec_group)
|
||||
self.__generator = openssl._get_ec_generator_by_group(self.ec_group)
|
||||
|
||||
# Init cache
|
||||
self.__field_order_size_in_bytes = 0
|
||||
self.__group_order_size_in_bytes = 0
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str) -> 'Curve':
|
||||
"""
|
||||
Alternate constructor to generate a curve instance by its name.
|
||||
|
||||
Raises NotImplementedError if the name cannot be mapped to a known
|
||||
supported curve NID.
|
||||
|
||||
"""
|
||||
|
||||
name = name.casefold() # normalize
|
||||
|
||||
for supported_nid, supported_name in cls._supported_curves.items():
|
||||
if name == supported_name:
|
||||
instance = cls(nid=supported_nid)
|
||||
break
|
||||
else:
|
||||
message = "{} is not supported curve name.".format(name)
|
||||
raise NotImplementedError(message)
|
||||
|
||||
return instance
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__curve_nid == other.curve_nid
|
||||
|
||||
def __str__(self):
|
||||
return "<OpenSSL Curve(nid={}, name={})>".format(self.__curve_nid, self.__curve_name)
|
||||
|
||||
#
|
||||
# Immutable Curve Data
|
||||
#
|
||||
|
||||
@property
|
||||
def field_order_size_in_bytes(self) -> int:
|
||||
if not self.__field_order_size_in_bytes:
|
||||
size_in_bits = openssl._get_ec_group_degree(self.__ec_group)
|
||||
self.__field_order_size_in_bytes = (size_in_bits + 7) // 8
|
||||
return self.__field_order_size_in_bytes
|
||||
|
||||
@property
|
||||
def group_order_size_in_bytes(self) -> int:
|
||||
if not self.__group_order_size_in_bytes:
|
||||
BN_num_bytes = default_backend()._lib.BN_num_bytes
|
||||
self.__group_order_size_in_bytes = BN_num_bytes(self.order)
|
||||
return self.__group_order_size_in_bytes
|
||||
|
||||
@property
|
||||
def curve_nid(self) -> int:
|
||||
return self.__curve_nid
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.__curve_name
|
||||
|
||||
@property
|
||||
def ec_group(self):
|
||||
return self.__ec_group
|
||||
|
||||
@property
|
||||
def order(self):
|
||||
return self.__order
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return self.__generator
|
||||
|
||||
|
||||
#
|
||||
# Global Curve Instances
|
||||
#
|
||||
|
||||
SECP256R1 = Curve.from_name('secp256r1')
|
||||
SECP256K1 = Curve.from_name('secp256k1')
|
||||
SECP384R1 = Curve.from_name('secp384r1')
|
||||
SECP256R1 = openssl.Curve.from_name('secp256r1')
|
||||
SECP256K1 = openssl.Curve.from_name('secp256k1')
|
||||
SECP384R1 = openssl.Curve.from_name('secp384r1')
|
||||
|
||||
CURVES = (SECP256K1, SECP256R1, SECP384R1)
|
||||
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
from cryptography.hazmat.backends.openssl import backend
|
||||
|
||||
from . import openssl
|
||||
from .curve import CURVE
|
||||
from .curve_scalar import CurveScalar
|
||||
|
@ -18,7 +16,7 @@ class CurvePoint(Serializable):
|
|||
|
||||
@classmethod
|
||||
def generator(cls) -> 'CurvePoint':
|
||||
return cls(CURVE.generator)
|
||||
return cls(CURVE.point_generator)
|
||||
|
||||
@classmethod
|
||||
def random(cls) -> 'CurvePoint':
|
||||
|
@ -29,101 +27,54 @@ class CurvePoint(Serializable):
|
|||
return cls.generator() * CurveScalar.random_nonzero()
|
||||
|
||||
@classmethod
|
||||
def from_affine(cls, coords: Tuple[int, int]) -> 'CurvePoint':
|
||||
def from_affine(cls, affine_x: int, affine_y: int) -> 'CurvePoint':
|
||||
"""
|
||||
Returns a CurvePoint object from the given affine coordinates in a tuple in
|
||||
the format of (x, y) and a given curve.
|
||||
"""
|
||||
affine_x, affine_y = coords
|
||||
if type(affine_x) == int:
|
||||
affine_x = openssl._int_to_bn(affine_x, curve=None)
|
||||
|
||||
if type(affine_y) == int:
|
||||
affine_y = openssl._int_to_bn(affine_y, curve=None)
|
||||
|
||||
backend_point = openssl._get_EC_POINT_via_affine(affine_x, affine_y, CURVE)
|
||||
backend_point = openssl.point_from_affine_coords(CURVE, affine_x, affine_y)
|
||||
return cls(backend_point)
|
||||
|
||||
def to_affine(self):
|
||||
def to_affine(self) -> Tuple[int, int]:
|
||||
"""
|
||||
Returns a tuple of Python ints in the format of (x, y) that represents
|
||||
the point in the curve.
|
||||
"""
|
||||
affine_x, affine_y = openssl._get_affine_coords_via_EC_POINT(self._backend_point, CURVE)
|
||||
return (backend._bn_to_int(affine_x), backend._bn_to_int(affine_y))
|
||||
return openssl.point_to_affine_coords(CURVE, self._backend_point)
|
||||
|
||||
@classmethod
|
||||
def __take__(cls, data: bytes) -> Tuple['CurvePoint', bytes]:
|
||||
"""
|
||||
Returns a CurvePoint object from the given byte data on the curve provided.
|
||||
"""
|
||||
size = CURVE.field_order_size_in_bytes + 1 # compressed point size
|
||||
size = CURVE.field_element_size + 1 # compressed point size
|
||||
point_data, data = cls.__take_bytes__(data, size)
|
||||
|
||||
point = openssl._get_new_EC_POINT(CURVE)
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_POINT_oct2point(
|
||||
CURVE.ec_group, point, point_data, len(point_data), bn_ctx);
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
point = openssl.point_from_bytes(CURVE, point_data)
|
||||
return cls(point), data
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
"""
|
||||
Returns the CurvePoint serialized as bytes in the compressed form.
|
||||
"""
|
||||
point_conversion_form = backend._lib.POINT_CONVERSION_COMPRESSED
|
||||
size = CURVE.field_order_size_in_bytes + 1 # compressed point size
|
||||
|
||||
bin_ptr = backend._ffi.new("unsigned char[]", size)
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
bin_len = backend._lib.EC_POINT_point2oct(
|
||||
CURVE.ec_group, self._backend_point, point_conversion_form,
|
||||
bin_ptr, size, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(bin_len != 0)
|
||||
|
||||
return bytes(backend._ffi.buffer(bin_ptr, bin_len)[:])
|
||||
return openssl.point_to_bytes_compressed(CURVE, self._backend_point)
|
||||
|
||||
def __eq__(self, other):
|
||||
"""
|
||||
Compares two EC_POINTS for equality.
|
||||
"""
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
is_equal = backend._lib.EC_POINT_cmp(
|
||||
CURVE.ec_group, self._backend_point, other._backend_point, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(is_equal != -1)
|
||||
|
||||
# 1 is not-equal, 0 is equal, -1 is error
|
||||
return not bool(is_equal)
|
||||
return openssl.point_eq(CURVE, self._backend_point, other._backend_point)
|
||||
|
||||
def __mul__(self, other: CurveScalar) -> 'CurvePoint':
|
||||
"""
|
||||
Performs an EC_POINT_mul on an EC_POINT and a BIGNUM.
|
||||
"""
|
||||
# TODO: Check that both points use the same curve.
|
||||
prod = openssl._get_new_EC_POINT(CURVE)
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_POINT_mul(
|
||||
CURVE.ec_group, prod, backend._ffi.NULL,
|
||||
self._backend_point, other._backend_bignum, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
return CurvePoint(prod)
|
||||
return CurvePoint(openssl.point_mul_bn(CURVE, self._backend_point, other._backend_bignum))
|
||||
|
||||
def __add__(self, other: 'CurvePoint') -> 'CurvePoint':
|
||||
"""
|
||||
Performs an EC_POINT_add on two EC_POINTS.
|
||||
"""
|
||||
op_sum = openssl._get_new_EC_POINT(CURVE)
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_POINT_add(
|
||||
CURVE.ec_group, op_sum, self._backend_point, other._backend_point, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(res == 1)
|
||||
return CurvePoint(op_sum)
|
||||
return CurvePoint(openssl.point_add(CURVE, self._backend_point, other._backend_point))
|
||||
|
||||
def __sub__(self, other: 'CurvePoint') -> 'CurvePoint':
|
||||
"""
|
||||
|
@ -136,13 +87,4 @@ class CurvePoint(Serializable):
|
|||
Computes the additive inverse of a CurvePoint, by performing an
|
||||
EC_POINT_invert on itself.
|
||||
"""
|
||||
inv = backend._lib.EC_POINT_dup(self._backend_point, CURVE.ec_group)
|
||||
backend.openssl_assert(inv != backend._ffi.NULL)
|
||||
inv = backend._ffi.gc(inv, backend._lib.EC_POINT_clear_free)
|
||||
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_POINT_invert(
|
||||
CURVE.ec_group, inv, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(res == 1)
|
||||
return CurvePoint(inv)
|
||||
return CurvePoint(openssl.point_neg(CURVE, self._backend_point))
|
||||
|
|
|
@ -1,23 +1,23 @@
|
|||
from typing import Optional, Union, Tuple
|
||||
|
||||
from cryptography.hazmat.backends.openssl import backend
|
||||
from typing import TYPE_CHECKING, Optional, Union, Tuple
|
||||
|
||||
from . import openssl
|
||||
from .curve import CURVE
|
||||
from .serializable import Serializable
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .hashing import Hash
|
||||
|
||||
|
||||
class CurveScalar(Serializable):
|
||||
"""
|
||||
Represents an OpenSSL Bignum modulo the order of a curve. Some of these
|
||||
operations will only work with prime numbers.
|
||||
|
||||
By default, the underlying OpenSSL BIGNUM has BN_FLG_CONSTTIME set for
|
||||
constant time operations.
|
||||
"""
|
||||
|
||||
def __init__(self, backend_bignum):
|
||||
on_curve = openssl._bn_is_on_curve(backend_bignum, CURVE)
|
||||
if not on_curve:
|
||||
if not openssl.bn_is_normalized(backend_bignum, CURVE.bn_order):
|
||||
raise ValueError("The provided BIGNUM is not on the provided curve.")
|
||||
|
||||
self._backend_bignum = backend_bignum
|
||||
|
@ -27,57 +27,42 @@ class CurveScalar(Serializable):
|
|||
"""
|
||||
Returns a CurveScalar object with a cryptographically secure OpenSSL BIGNUM.
|
||||
"""
|
||||
one = cls.one()._backend_bignum
|
||||
|
||||
# TODO: in most cases, we want this number to be secret.
|
||||
# OpenSSL 1.1.1 has `BN_priv_rand_range()`, but it is not
|
||||
# currently exported by `cryptography`.
|
||||
# Use when available.
|
||||
|
||||
# Calculate `order - 1`
|
||||
order_minus_1 = openssl._get_new_BN()
|
||||
res = backend._lib.BN_sub(order_minus_1, CURVE.order, one)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
# Get a random in range `[0, order - 1)`
|
||||
new_rand_bn = openssl._get_new_BN()
|
||||
res = backend._lib.BN_rand_range(new_rand_bn, order_minus_1)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
# Turn it into a random in range `[1, order)`
|
||||
op_sum = openssl._get_new_BN()
|
||||
res = backend._lib.BN_add(op_sum, new_rand_bn, one)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
return cls(op_sum)
|
||||
return cls(openssl.bn_random_nonzero(CURVE.bn_order))
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, num: int) -> 'CurveScalar':
|
||||
"""
|
||||
Returns a CurveScalar object from a given integer on a curve.
|
||||
"""
|
||||
conv_bn = openssl._int_to_bn(num, CURVE)
|
||||
conv_bn = openssl.bn_from_int(num, modulus=CURVE.bn_order)
|
||||
return cls(conv_bn)
|
||||
|
||||
@classmethod
|
||||
def from_digest(cls, digest: 'Hash') -> 'CurveScalar':
|
||||
# TODO (#39): to be replaced by the standard algroithm.
|
||||
# Currently just matching what we have in RustCrypto stack
|
||||
# (taking bytes modulo curve order).
|
||||
# Can produce zeros!
|
||||
bn = openssl.bn_from_bytes(digest.finalize(), modulus=CURVE.bn_order)
|
||||
return cls(bn)
|
||||
|
||||
@classmethod
|
||||
def __take__(cls, data: bytes) -> Tuple['CurveScalar', bytes]:
|
||||
size = backend._lib.BN_num_bytes(CURVE.order)
|
||||
scalar_data, data = cls.__take_bytes__(data, size)
|
||||
bignum = openssl._bytes_to_bn(scalar_data)
|
||||
scalar_data, data = cls.__take_bytes__(data, CURVE.scalar_size)
|
||||
bignum = openssl.bn_from_bytes(scalar_data)
|
||||
return cls(bignum), data
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
"""
|
||||
Returns the CurveScalar as bytes.
|
||||
"""
|
||||
size = backend._lib.BN_num_bytes(CURVE.order)
|
||||
return openssl._bn_to_bytes(self._backend_bignum, size)
|
||||
return openssl.bn_to_bytes(self._backend_bignum, CURVE.scalar_size)
|
||||
|
||||
def __int__(self) -> int:
|
||||
"""
|
||||
Converts the CurveScalar to a Python int.
|
||||
"""
|
||||
return backend._bn_to_int(self._backend_bignum)
|
||||
return openssl.bn_to_int(self._backend_bignum)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
"""
|
||||
|
@ -85,17 +70,14 @@ class CurveScalar(Serializable):
|
|||
"""
|
||||
if type(other) == int:
|
||||
other = CurveScalar.from_int(other)
|
||||
|
||||
# -1 less than, 0 is equal to, 1 is greater than
|
||||
return not bool(backend._lib.BN_cmp(self._backend_bignum, other._backend_bignum))
|
||||
return openssl.bn_cmp(self._backend_bignum, other._backend_bignum) == 0
|
||||
|
||||
@classmethod
|
||||
def one(cls):
|
||||
return cls(backend._lib.BN_value_one())
|
||||
return cls(openssl.bn_one())
|
||||
|
||||
def is_zero(self):
|
||||
# BN_is_zero() is not exported, so this will have to do
|
||||
return self == 0
|
||||
return openssl.bn_is_zero(self._backend_bignum)
|
||||
|
||||
def __mul__(self, other: Union[int, 'CurveScalar']) -> 'CurveScalar':
|
||||
"""
|
||||
|
@ -103,15 +85,7 @@ class CurveScalar(Serializable):
|
|||
"""
|
||||
if isinstance(other, int):
|
||||
other = CurveScalar.from_int(other)
|
||||
|
||||
product = openssl._get_new_BN()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.BN_mod_mul(
|
||||
product, self._backend_bignum, other._backend_bignum, CURVE.order, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
return CurveScalar(product)
|
||||
return CurveScalar(openssl.bn_mul(self._backend_bignum, other._backend_bignum, CURVE.bn_order))
|
||||
|
||||
def __add__(self, other : Union[int, 'CurveScalar']) -> 'CurveScalar':
|
||||
"""
|
||||
|
@ -119,15 +93,7 @@ class CurveScalar(Serializable):
|
|||
"""
|
||||
if isinstance(other, int):
|
||||
other = CurveScalar.from_int(other)
|
||||
|
||||
op_sum = openssl._get_new_BN()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.BN_mod_add(
|
||||
op_sum, self._backend_bignum, other._backend_bignum, CURVE.order, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
return CurveScalar(op_sum)
|
||||
return CurveScalar(openssl.bn_add(self._backend_bignum, other._backend_bignum, CURVE.bn_order))
|
||||
|
||||
def __sub__(self, other : Union[int, 'CurveScalar']) -> 'CurveScalar':
|
||||
"""
|
||||
|
@ -135,26 +101,11 @@ class CurveScalar(Serializable):
|
|||
"""
|
||||
if isinstance(other, int):
|
||||
other = CurveScalar.from_int(other)
|
||||
|
||||
diff = openssl._get_new_BN()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.BN_mod_sub(
|
||||
diff, self._backend_bignum, other._backend_bignum, CURVE.order, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
return CurveScalar(diff)
|
||||
return CurveScalar(openssl.bn_sub(self._backend_bignum, other._backend_bignum, CURVE.bn_order))
|
||||
|
||||
def invert(self) -> 'CurveScalar':
|
||||
"""
|
||||
Performs a BN_mod_inverse.
|
||||
WARNING: Only in constant time if BN_FLG_CONSTTIME is set on the BN.
|
||||
"""
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
inv = backend._lib.BN_mod_inverse(
|
||||
backend._ffi.NULL, self._backend_bignum, CURVE.order, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(inv != backend._ffi.NULL)
|
||||
inv = backend._ffi.gc(inv, backend._lib.BN_clear_free)
|
||||
|
||||
return CurveScalar(inv)
|
||||
return CurveScalar(openssl.bn_invert(self._backend_bignum, CURVE.bn_order))
|
||||
|
|
|
@ -2,7 +2,6 @@ import os
|
|||
from typing import Optional
|
||||
|
||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
from cryptography.hazmat.backends.openssl import backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
||||
from nacl.bindings.crypto_aead import (
|
||||
|
@ -12,6 +11,8 @@ from nacl.bindings.crypto_aead import (
|
|||
crypto_aead_xchacha20poly1305_ietf_NPUBBYTES as XCHACHA_NONCE_SIZE,
|
||||
)
|
||||
|
||||
from . import openssl
|
||||
|
||||
|
||||
def kdf(data: bytes,
|
||||
key_length: int,
|
||||
|
@ -23,7 +24,7 @@ def kdf(data: bytes,
|
|||
length=key_length,
|
||||
salt=salt,
|
||||
info=info,
|
||||
backend=backend)
|
||||
backend=openssl.backend)
|
||||
return hkdf.derive(data)
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
from typing import TYPE_CHECKING, Optional, Type, Iterable
|
||||
|
||||
from cryptography.hazmat.backends.openssl import backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.exceptions import InternalError
|
||||
|
||||
from . import openssl
|
||||
from .openssl import backend, ErrorInvalidCompressedPoint
|
||||
from .curve import CURVE
|
||||
from .curve_scalar import CurveScalar
|
||||
from .curve_point import CurvePoint
|
||||
|
@ -19,30 +17,17 @@ class Hash:
|
|||
OUTPUT_SIZE = 32
|
||||
|
||||
def __init__(self, dst: bytes):
|
||||
self._sha256 = hashes.Hash(hashes.SHA256(), backend=backend)
|
||||
self._backend_hash_algorithm = hashes.SHA256()
|
||||
self._hash = hashes.Hash(self._backend_hash_algorithm, backend=backend)
|
||||
|
||||
len_dst = len(dst).to_bytes(4, byteorder='big')
|
||||
self.update(len_dst + dst)
|
||||
|
||||
def update(self, data: bytes) -> None:
|
||||
self._sha256.update(data)
|
||||
self._hash.update(data)
|
||||
|
||||
def finalize(self) -> bytes:
|
||||
return self._sha256.finalize()
|
||||
|
||||
|
||||
def digest_to_scalar(digest: Hash) -> CurveScalar:
|
||||
# TODO: to be replaced by the standard algroithm.
|
||||
# Currently just matching what we have in RustCrypto stack.
|
||||
# Can produce zeros!
|
||||
|
||||
hash_digest = openssl._bytes_to_bn(digest.finalize())
|
||||
|
||||
bignum = openssl._get_new_BN()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.BN_mod(bignum, hash_digest, CURVE.order, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
return CurveScalar(bignum)
|
||||
return self._hash.finalize()
|
||||
|
||||
|
||||
def hash_to_polynomial_arg(precursor: CurvePoint,
|
||||
|
@ -55,14 +40,14 @@ def hash_to_polynomial_arg(precursor: CurvePoint,
|
|||
digest.update(bytes(pubkey))
|
||||
digest.update(bytes(dh_point))
|
||||
digest.update(bytes(kfrag_id))
|
||||
return digest_to_scalar(digest)
|
||||
return CurveScalar.from_digest(digest)
|
||||
|
||||
|
||||
def hash_capsule_points(e: CurvePoint, v: CurvePoint) -> CurveScalar:
|
||||
digest = Hash(b"CAPSULE_POINTS")
|
||||
digest.update(bytes(e))
|
||||
digest.update(bytes(v))
|
||||
return digest_to_scalar(digest)
|
||||
return CurveScalar.from_digest(digest)
|
||||
|
||||
|
||||
def hash_to_shared_secret(precursor: CurvePoint,
|
||||
|
@ -73,7 +58,7 @@ def hash_to_shared_secret(precursor: CurvePoint,
|
|||
digest.update(bytes(precursor))
|
||||
digest.update(bytes(pubkey))
|
||||
digest.update(bytes(dh_point))
|
||||
return digest_to_scalar(digest)
|
||||
return CurveScalar.from_digest(digest)
|
||||
|
||||
|
||||
|
||||
|
@ -83,7 +68,7 @@ def hash_to_cfrag_verification(points: Iterable[CurvePoint], metadata: Optional[
|
|||
digest.update(bytes(point))
|
||||
if metadata is not None:
|
||||
digest.update(metadata)
|
||||
return digest_to_scalar(digest)
|
||||
return CurveScalar.from_digest(digest)
|
||||
|
||||
|
||||
def hash_to_cfrag_signature(kfrag_id: 'KeyFragID',
|
||||
|
@ -122,10 +107,10 @@ class SignatureDigest:
|
|||
self._digest.update(value)
|
||||
|
||||
def sign(self, sk: SecretKey) -> Signature:
|
||||
return sk.sign_digest(self._digest, hashes.SHA256)
|
||||
return sk.sign_digest(self._digest)
|
||||
|
||||
def verify(self, pk: PublicKey, sig: Signature):
|
||||
return sig.verify_digest(pk, self._digest, hashes.SHA256)
|
||||
return sig.verify_digest(pk, self._digest)
|
||||
|
||||
|
||||
def unsafe_hash_to_point(dst: bytes, data: bytes) -> CurvePoint:
|
||||
|
@ -146,22 +131,15 @@ def unsafe_hash_to_point(dst: bytes, data: bytes) -> CurvePoint:
|
|||
ibytes = i.to_bytes(4, byteorder='big')
|
||||
digest = Hash(dst)
|
||||
digest.update(data_with_len + ibytes)
|
||||
point_data = digest.finalize()[:CURVE.field_order_size_in_bytes]
|
||||
point_data = digest.finalize()[:CURVE.field_element_size]
|
||||
|
||||
compressed_point = sign + point_data
|
||||
|
||||
try:
|
||||
return CurvePoint.from_bytes(compressed_point)
|
||||
except InternalError as e:
|
||||
# We want to catch specific InternalExceptions:
|
||||
# - Point not in the curve (code 107)
|
||||
# - Invalid compressed point (code 110)
|
||||
# https://github.com/openssl/openssl/blob/master/include/openssl/ecerr.h#L228
|
||||
if e.err_code[0].reason in (107, 110):
|
||||
pass
|
||||
else:
|
||||
# Any other exception, we raise it
|
||||
raise e
|
||||
except ErrorInvalidCompressedPoint:
|
||||
# If it is not a valid point, continue on
|
||||
pass
|
||||
|
||||
# Only happens with probability 2^(-32)
|
||||
raise ValueError('Could not hash input into the curve') # pragma: no cover
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.backends.openssl import backend
|
||||
from cryptography.hazmat.backends.openssl.ec import _EllipticCurvePrivateKey, _EllipticCurvePublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric import utils
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA
|
||||
|
||||
|
@ -51,51 +49,19 @@ class SecretKey(Serializable):
|
|||
def __bytes__(self) -> bytes:
|
||||
return bytes(self._scalar_key)
|
||||
|
||||
def to_cryptography_privkey(self) -> _EllipticCurvePrivateKey:
|
||||
"""
|
||||
Returns a cryptography.io EllipticCurvePrivateKey from the Umbral key.
|
||||
"""
|
||||
ec_key = backend._lib.EC_KEY_new()
|
||||
backend.openssl_assert(ec_key != backend._ffi.NULL)
|
||||
ec_key = backend._ffi.gc(ec_key, backend._lib.EC_KEY_free)
|
||||
def sign_digest(self, digest: 'Hash') -> 'Signature':
|
||||
|
||||
set_group_result = backend._lib.EC_KEY_set_group(ec_key, CURVE.ec_group)
|
||||
backend.openssl_assert(set_group_result == 1)
|
||||
|
||||
set_privkey_result = backend._lib.EC_KEY_set_private_key(
|
||||
ec_key, self._scalar_key._backend_bignum
|
||||
)
|
||||
backend.openssl_assert(set_privkey_result == 1)
|
||||
|
||||
# Get public key
|
||||
point = openssl._get_new_EC_POINT(CURVE)
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
mult_result = backend._lib.EC_POINT_mul(
|
||||
CURVE.ec_group, point, self._scalar_key._backend_bignum,
|
||||
backend._ffi.NULL, backend._ffi.NULL, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(mult_result == 1)
|
||||
|
||||
set_pubkey_result = backend._lib.EC_KEY_set_public_key(ec_key, point)
|
||||
backend.openssl_assert(set_pubkey_result == 1)
|
||||
|
||||
evp_pkey = backend._ec_cdata_to_evp_pkey(ec_key)
|
||||
return _EllipticCurvePrivateKey(backend, ec_key, evp_pkey)
|
||||
|
||||
def sign_digest(self, digest: 'Hash', backend_hash_algorithm) -> 'Signature':
|
||||
|
||||
signature_algorithm = ECDSA(utils.Prehashed(backend_hash_algorithm()))
|
||||
signature_algorithm = ECDSA(utils.Prehashed(digest._backend_hash_algorithm))
|
||||
message = digest.finalize()
|
||||
|
||||
cpk = self.to_cryptography_privkey()
|
||||
signature_der_bytes = cpk.sign(message, signature_algorithm)
|
||||
backend_sk = openssl.bn_to_privkey(CURVE, self._scalar_key._backend_bignum)
|
||||
signature_der_bytes = backend_sk.sign(message, signature_algorithm)
|
||||
r, s = utils.decode_dss_signature(signature_der_bytes)
|
||||
|
||||
# Normalize s
|
||||
# s is public, so no constant-timeness required here
|
||||
order = backend._bn_to_int(CURVE.order)
|
||||
if s > (order >> 1):
|
||||
s = order - s
|
||||
if s > (CURVE.order >> 1):
|
||||
s = CURVE.order - s
|
||||
|
||||
return Signature(CurveScalar.from_int(r), CurveScalar.from_int(s))
|
||||
|
||||
|
@ -114,17 +80,18 @@ class Signature(Serializable):
|
|||
def __repr__(self):
|
||||
return f"ECDSA Signature: {bytes(self).hex()[:15]}"
|
||||
|
||||
def verify_digest(self, verifying_key: 'PublicKey', digest: 'Hash', backend_hash_algorithm) -> bool:
|
||||
cryptography_pub_key = verifying_key.to_cryptography_pubkey()
|
||||
signature_algorithm = ECDSA(utils.Prehashed(backend_hash_algorithm()))
|
||||
def verify_digest(self, verifying_key: 'PublicKey', digest: 'Hash') -> bool:
|
||||
backend_pk = openssl.point_to_pubkey(CURVE, verifying_key.point()._backend_point)
|
||||
signature_algorithm = ECDSA(utils.Prehashed(digest._backend_hash_algorithm))
|
||||
|
||||
message = digest.finalize()
|
||||
signature_der_bytes = utils.encode_dss_signature(int(self.r), int(self.s))
|
||||
|
||||
# TODO: Raise error instead of returning boolean
|
||||
try:
|
||||
cryptography_pub_key.verify(signature=signature_der_bytes,
|
||||
data=message,
|
||||
signature_algorithm=signature_algorithm)
|
||||
backend_pk.verify(signature=signature_der_bytes,
|
||||
data=message,
|
||||
signature_algorithm=signature_algorithm)
|
||||
except InvalidSignature:
|
||||
return False
|
||||
return True
|
||||
|
@ -161,25 +128,6 @@ class PublicKey(Serializable):
|
|||
def __bytes__(self) -> bytes:
|
||||
return bytes(self._point_key)
|
||||
|
||||
def to_cryptography_pubkey(self) -> _EllipticCurvePublicKey:
|
||||
"""
|
||||
Returns a cryptography.io EllipticCurvePublicKey from the Umbral key.
|
||||
"""
|
||||
ec_key = backend._lib.EC_KEY_new()
|
||||
backend.openssl_assert(ec_key != backend._ffi.NULL)
|
||||
ec_key = backend._ffi.gc(ec_key, backend._lib.EC_KEY_free)
|
||||
|
||||
set_group_result = backend._lib.EC_KEY_set_group(ec_key, CURVE.ec_group)
|
||||
backend.openssl_assert(set_group_result == 1)
|
||||
|
||||
set_pubkey_result = backend._lib.EC_KEY_set_public_key(
|
||||
ec_key, self._point_key._backend_point
|
||||
)
|
||||
backend.openssl_assert(set_pubkey_result == 1)
|
||||
|
||||
evp_pkey = backend._ec_cdata_to_evp_pkey(ec_key)
|
||||
return _EllipticCurvePublicKey(backend, ec_key, evp_pkey)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.__class__.__name__}:{bytes(self).hex()[:16]}"
|
||||
|
||||
|
|
|
@ -1,11 +1,122 @@
|
|||
from contextlib import contextmanager
|
||||
import typing
|
||||
from typing import Tuple
|
||||
|
||||
from cryptography.exceptions import InternalError
|
||||
from cryptography.hazmat.backends.openssl import backend
|
||||
from cryptography.hazmat.backends.openssl.ec import _EllipticCurvePrivateKey, _EllipticCurvePublicKey
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _get_new_BN(set_consttime_flag=True):
|
||||
class Curve:
|
||||
"""
|
||||
Acts as a container to store constant variables such as the OpenSSL
|
||||
curve_nid, the EC_GROUP struct, and the order of the curve.
|
||||
|
||||
Contains a whitelist of supported elliptic curves used in pyUmbral.
|
||||
"""
|
||||
|
||||
_supported_curves = {
|
||||
415: 'secp256r1',
|
||||
714: 'secp256k1',
|
||||
715: 'secp384r1'
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_ec_group_by_curve_nid(nid: int):
|
||||
"""
|
||||
Returns the group of a given curve via its OpenSSL nid. This must be freed
|
||||
after each use otherwise it leaks memory.
|
||||
"""
|
||||
group = backend._lib.EC_GROUP_new_by_curve_name(nid)
|
||||
backend.openssl_assert(group != backend._ffi.NULL)
|
||||
return group
|
||||
|
||||
@staticmethod
|
||||
def _get_ec_order_by_group(ec_group):
|
||||
"""
|
||||
Returns the order of a given curve via its OpenSSL EC_GROUP.
|
||||
"""
|
||||
ec_order = _bn_new()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_GROUP_get_order(ec_group, ec_order, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
return ec_order
|
||||
|
||||
@staticmethod
|
||||
def _get_ec_generator_by_group(ec_group):
|
||||
"""
|
||||
Returns the generator point of a given curve via its OpenSSL EC_GROUP.
|
||||
"""
|
||||
generator = backend._lib.EC_GROUP_get0_generator(ec_group)
|
||||
backend.openssl_assert(generator != backend._ffi.NULL)
|
||||
generator = backend._ffi.gc(generator, backend._lib.EC_POINT_clear_free)
|
||||
|
||||
return generator
|
||||
|
||||
@staticmethod
|
||||
def _get_ec_group_degree(ec_group):
|
||||
"""
|
||||
Returns the number of bits needed to represent the order of the finite
|
||||
field upon the curve is based.
|
||||
"""
|
||||
return backend._lib.EC_GROUP_get_degree(ec_group)
|
||||
|
||||
def __init__(self, nid: int):
|
||||
"""
|
||||
Instantiates an OpenSSL curve with the provided curve_nid and derives
|
||||
the proper EC_GROUP struct and order. You can _only_ instantiate curves
|
||||
with supported nids (see `Curve.supported_curves`).
|
||||
"""
|
||||
|
||||
try:
|
||||
self.name = self._supported_curves[nid]
|
||||
except KeyError:
|
||||
raise NotImplementedError("Curve NID {} is not supported.".format(nid))
|
||||
|
||||
self.nid = nid
|
||||
|
||||
self.ec_group = self._get_ec_group_by_curve_nid(self.nid)
|
||||
self.bn_order = self._get_ec_order_by_group(self.ec_group)
|
||||
self.point_generator = self._get_ec_generator_by_group(self.ec_group)
|
||||
|
||||
size_in_bits = self._get_ec_group_degree(self.ec_group)
|
||||
self.field_element_size = (size_in_bits + 7) // 8
|
||||
|
||||
self.scalar_size = _bn_size(self.bn_order)
|
||||
self.order = bn_to_int(self.bn_order)
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str) -> 'Curve':
|
||||
"""
|
||||
Alternate constructor to generate a curve instance by its name.
|
||||
|
||||
Raises NotImplementedError if the name cannot be mapped to a known
|
||||
supported curve NID.
|
||||
"""
|
||||
|
||||
name = name.casefold() # normalize
|
||||
|
||||
for supported_nid, supported_name in cls._supported_curves.items():
|
||||
if name == supported_name:
|
||||
instance = cls(nid=supported_nid)
|
||||
break
|
||||
else:
|
||||
raise NotImplementedError(f"{name} is not supported curve name.")
|
||||
|
||||
return instance
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.nid == other.nid
|
||||
|
||||
def __str__(self):
|
||||
return "<OpenSSL Curve(nid={}, name={})>".format(self.nid, self.name)
|
||||
|
||||
|
||||
#
|
||||
# OpenSSL bignums
|
||||
#
|
||||
|
||||
|
||||
def _bn_new(set_consttime_flag=True):
|
||||
"""
|
||||
Returns a new and initialized OpenSSL BIGNUM.
|
||||
The set_consttime_flag is set to True by default. When this instance of a
|
||||
|
@ -21,71 +132,22 @@ def _get_new_BN(set_consttime_flag=True):
|
|||
return new_bn
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _get_ec_group_by_curve_nid(curve_nid: int):
|
||||
def bn_is_normalized(check_bn, modulus):
|
||||
"""
|
||||
Returns the group of a given curve via its OpenSSL nid. This must be freed
|
||||
after each use otherwise it leaks memory.
|
||||
"""
|
||||
group = backend._lib.EC_GROUP_new_by_curve_name(curve_nid)
|
||||
backend.openssl_assert(group != backend._ffi.NULL)
|
||||
|
||||
return group
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _get_ec_order_by_group(ec_group):
|
||||
"""
|
||||
Returns the order of a given curve via its OpenSSL EC_GROUP.
|
||||
"""
|
||||
ec_order = _get_new_BN()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_GROUP_get_order(ec_group, ec_order, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
return ec_order
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _get_ec_generator_by_group(ec_group):
|
||||
"""
|
||||
Returns the generator point of a given curve via its OpenSSL EC_GROUP.
|
||||
"""
|
||||
generator = backend._lib.EC_GROUP_get0_generator(ec_group)
|
||||
backend.openssl_assert(generator != backend._ffi.NULL)
|
||||
generator = backend._ffi.gc(generator, backend._lib.EC_POINT_clear_free)
|
||||
|
||||
return generator
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _get_ec_group_degree(ec_group):
|
||||
"""
|
||||
Returns the number of bits needed to represent the order of the finite
|
||||
field upon the curve is based.
|
||||
"""
|
||||
return backend._lib.EC_GROUP_get_degree(ec_group)
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _bn_is_on_curve(check_bn, curve: 'Curve'):
|
||||
"""
|
||||
Checks if a given OpenSSL BIGNUM is within the provided curve's order.
|
||||
Returns True if the provided BN is on the curve, that is in the range `[0, curve_order)`.
|
||||
Returns ``True`` if ``check_bn`` is in ``[0, modulus)``, ``False`` otherwise.
|
||||
"""
|
||||
zero = backend._int_to_bn(0)
|
||||
zero = backend._ffi.gc(zero, backend._lib.BN_clear_free)
|
||||
|
||||
check_sign = backend._lib.BN_cmp(check_bn, zero)
|
||||
range_check = backend._lib.BN_cmp(check_bn, curve.order)
|
||||
range_check = backend._lib.BN_cmp(check_bn, modulus)
|
||||
return (check_sign == 1 or check_sign == 0) and range_check == -1
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _int_to_bn(py_int: int, curve: 'Curve'=None, set_consttime_flag=True):
|
||||
def bn_from_int(py_int: int, modulus=None, set_consttime_flag=True):
|
||||
"""
|
||||
Converts the given Python int to an OpenSSL BIGNUM. If a curve is
|
||||
provided, it will check if the Python integer is within the order of that
|
||||
curve. If it's not within the order, it will raise a ValueError.
|
||||
Converts the given Python int to an OpenSSL BIGNUM. If ``modulus`` is
|
||||
provided, it will check if the Python integer is within ``[0, modulus)``.
|
||||
|
||||
If set_consttime_flag is set to True, OpenSSL will use constant time
|
||||
operations when using this CurveBN.
|
||||
|
@ -93,109 +155,290 @@ def _int_to_bn(py_int: int, curve: 'Curve'=None, set_consttime_flag=True):
|
|||
conv_bn = backend._int_to_bn(py_int)
|
||||
conv_bn = backend._ffi.gc(conv_bn, backend._lib.BN_clear_free)
|
||||
|
||||
if curve:
|
||||
on_curve = _bn_is_on_curve(conv_bn, curve)
|
||||
if not on_curve:
|
||||
raise ValueError("The Python integer given is not on the provided curve.")
|
||||
if modulus and not bn_is_normalized(conv_bn, modulus):
|
||||
raise ValueError("The Python integer given is not under the provided modulus.")
|
||||
|
||||
if set_consttime_flag:
|
||||
backend._lib.BN_set_flags(conv_bn, backend._lib.BN_FLG_CONSTTIME)
|
||||
return conv_bn
|
||||
|
||||
@typing.no_type_check
|
||||
def _bytes_to_bn(bytes_seq: bytes, set_consttime_flag=True):
|
||||
|
||||
def bn_from_bytes(bytes_seq: bytes, set_consttime_flag=True, modulus=None):
|
||||
"""
|
||||
Converts the given byte sequence to an OpenSSL BIGNUM.
|
||||
If set_consttime_flag is set to True, OpenSSL will use constant time
|
||||
operations when using this BIGNUM.
|
||||
"""
|
||||
bn = _get_new_BN(set_consttime_flag)
|
||||
bn = _bn_new(set_consttime_flag)
|
||||
backend._lib.BN_bin2bn(bytes_seq, len(bytes_seq), bn)
|
||||
backend.openssl_assert(bn != backend._ffi.NULL)
|
||||
|
||||
if modulus:
|
||||
bignum =_bn_new()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.BN_mod(bignum, bn, modulus, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
return bn
|
||||
|
||||
@typing.no_type_check
|
||||
def _bn_to_bytes(bignum, length : int = None):
|
||||
|
||||
def bn_to_bytes(bn, length: int):
|
||||
"""
|
||||
Converts the given OpenSSL BIGNUM into a Python bytes sequence.
|
||||
If length is given, the return bytes will have such length.
|
||||
If the BIGNUM doesn't fit, it raises a ValueError.
|
||||
"""
|
||||
|
||||
if bignum is None or bignum == backend._ffi.NULL:
|
||||
raise ValueError("Input BIGNUM must have a value")
|
||||
|
||||
bn_num_bytes = backend._lib.BN_num_bytes(bignum)
|
||||
if length is None:
|
||||
length = bn_num_bytes
|
||||
elif bn_num_bytes > length:
|
||||
raise ValueError("Input BIGNUM doesn't fit in {} B".format(length))
|
||||
# Sanity check, CurveScalar ensures it won't happen.
|
||||
bn_num_bytes = backend._lib.BN_num_bytes(bn)
|
||||
assert bn_num_bytes <= length, f"Input BIGNUM doesn't fit in {length} B"
|
||||
|
||||
bin_ptr = backend._ffi.new("unsigned char []", length)
|
||||
bin_len = backend._lib.BN_bn2bin(bignum, bin_ptr)
|
||||
bin_len = backend._lib.BN_bn2bin(bn, bin_ptr)
|
||||
return bytes.rjust(backend._ffi.buffer(bin_ptr, bin_len)[:], length, b'\0')
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _get_new_EC_POINT(curve: 'Curve'):
|
||||
def bn_random_nonzero(modulus):
|
||||
|
||||
one = backend._lib.BN_value_one()
|
||||
|
||||
# TODO: in most cases, we want this number to be secret.
|
||||
# OpenSSL 1.1.1 has `BN_priv_rand_range()`, but it is not
|
||||
# currently exported by `cryptography`.
|
||||
# Use when available.
|
||||
|
||||
# Calculate `modulus - 1`
|
||||
modulus_minus_1 = _bn_new()
|
||||
res = backend._lib.BN_sub(modulus_minus_1, modulus, one)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
# Get a random in range `[0, modulus - 1)`
|
||||
new_rand_bn = _bn_new()
|
||||
res = backend._lib.BN_rand_range(new_rand_bn, modulus_minus_1)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
# Turn it into a random in range `[1, modulus)`
|
||||
op_sum = _bn_new()
|
||||
res = backend._lib.BN_add(op_sum, new_rand_bn, one)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
return op_sum
|
||||
|
||||
|
||||
def _bn_size(bn):
|
||||
return backend._lib.BN_num_bytes(bn)
|
||||
|
||||
|
||||
def bn_to_int(bn):
|
||||
return backend._bn_to_int(bn)
|
||||
|
||||
|
||||
def bn_cmp(bn1, bn2):
|
||||
# -1 less than, 0 is equal to, 1 is greater than
|
||||
return backend._lib.BN_cmp(bn1, bn2)
|
||||
|
||||
|
||||
def bn_one():
|
||||
return backend._lib.BN_value_one()
|
||||
|
||||
|
||||
def bn_is_zero(bn):
|
||||
# No special function exported in the current backend, so this will have to do
|
||||
return bn_cmp(bn, bn_from_int(0)) == 0
|
||||
|
||||
|
||||
def bn_invert(bn, modulus):
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
inv = backend._lib.BN_mod_inverse(backend._ffi.NULL, bn, modulus, bn_ctx)
|
||||
backend.openssl_assert(inv != backend._ffi.NULL)
|
||||
inv = backend._ffi.gc(inv, backend._lib.BN_clear_free)
|
||||
return inv
|
||||
|
||||
|
||||
def bn_sub(bn1, bn2, modulus):
|
||||
diff = _bn_new()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.BN_mod_sub(diff, bn1, bn2, modulus, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
return diff
|
||||
|
||||
|
||||
def bn_add(bn1, bn2, modulus):
|
||||
op_sum = _bn_new()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.BN_mod_add(op_sum, bn1, bn2, modulus, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
return op_sum
|
||||
|
||||
|
||||
def bn_mul(bn1, bn2, modulus):
|
||||
product = _bn_new()
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.BN_mod_mul(product, bn1, bn2, modulus, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
return product
|
||||
|
||||
|
||||
def bn_to_privkey(curve: Curve, bn):
|
||||
|
||||
ec_key = backend._lib.EC_KEY_new()
|
||||
backend.openssl_assert(ec_key != backend._ffi.NULL)
|
||||
ec_key = backend._ffi.gc(ec_key, backend._lib.EC_KEY_free)
|
||||
|
||||
set_group_result = backend._lib.EC_KEY_set_group(ec_key, curve.ec_group)
|
||||
backend.openssl_assert(set_group_result == 1)
|
||||
|
||||
set_privkey_result = backend._lib.EC_KEY_set_private_key(ec_key, bn)
|
||||
backend.openssl_assert(set_privkey_result == 1)
|
||||
|
||||
evp_pkey = backend._ec_cdata_to_evp_pkey(ec_key)
|
||||
return _EllipticCurvePrivateKey(backend, ec_key, evp_pkey)
|
||||
|
||||
|
||||
#
|
||||
# OpenSSL EC points
|
||||
#
|
||||
|
||||
|
||||
def _point_new(ec_group):
|
||||
"""
|
||||
Returns a new and initialized OpenSSL EC_POINT given the group of a curve.
|
||||
If __curve_nid is provided, it retrieves the group from the curve provided.
|
||||
"""
|
||||
new_point = backend._lib.EC_POINT_new(curve.ec_group)
|
||||
new_point = backend._lib.EC_POINT_new(ec_group)
|
||||
backend.openssl_assert(new_point != backend._ffi.NULL)
|
||||
new_point = backend._ffi.gc(new_point, backend._lib.EC_POINT_clear_free)
|
||||
|
||||
return new_point
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _get_EC_POINT_via_affine(affine_x, affine_y, curve: 'Curve'):
|
||||
def point_from_affine_coords(curve: Curve, affine_x: int, affine_y: int):
|
||||
"""
|
||||
Returns an EC_POINT given the group of a curve and the affine coordinates
|
||||
provided.
|
||||
"""
|
||||
new_point = _get_new_EC_POINT(curve)
|
||||
bn_affine_x = bn_from_int(affine_x)
|
||||
bn_affine_y = bn_from_int(affine_y)
|
||||
|
||||
new_point = _point_new(curve.ec_group)
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_POINT_set_affine_coordinates_GFp(
|
||||
curve.ec_group, new_point, affine_x, affine_y, bn_ctx
|
||||
curve.ec_group, new_point, bn_affine_x, bn_affine_y, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(res == 1)
|
||||
return new_point
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _get_affine_coords_via_EC_POINT(ec_point, curve: 'Curve'):
|
||||
def point_to_affine_coords(curve: Curve, point) -> Tuple[int, int]:
|
||||
"""
|
||||
Returns the affine coordinates of a given point on the provided ec_group.
|
||||
"""
|
||||
affine_x = _get_new_BN()
|
||||
affine_y = _get_new_BN()
|
||||
affine_x = _bn_new()
|
||||
affine_y = _bn_new()
|
||||
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_POINT_get_affine_coordinates_GFp(
|
||||
curve.ec_group, ec_point, affine_x, affine_y, bn_ctx
|
||||
curve.ec_group, point, affine_x, affine_y, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(res == 1)
|
||||
return (affine_x, affine_y)
|
||||
|
||||
return bn_to_int(affine_x), bn_to_int(affine_y)
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
@contextmanager
|
||||
def _tmp_bn_mont_ctx(modulus):
|
||||
"""
|
||||
Initializes and returns a BN_MONT_CTX for Montgomery ops.
|
||||
Requires a modulus to place in the Montgomery structure.
|
||||
"""
|
||||
bn_mont_ctx = backend._lib.BN_MONT_CTX_new()
|
||||
backend.openssl_assert(bn_mont_ctx != backend._ffi.NULL)
|
||||
# Don't set the garbage collector. Only free it when the context is done
|
||||
# or else you'll get a null pointer error.
|
||||
class ErrorInvalidCompressedPoint(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ErrorInvalidPointEncoding(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def point_from_bytes(curve: Curve, data):
|
||||
point = _point_new(curve.ec_group)
|
||||
try:
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.BN_MONT_CTX_set(bn_mont_ctx, modulus, bn_ctx)
|
||||
res = backend._lib.EC_POINT_oct2point(curve.ec_group, point, data, len(data), bn_ctx);
|
||||
backend.openssl_assert(res == 1)
|
||||
yield bn_mont_ctx
|
||||
finally:
|
||||
backend._lib.BN_MONT_CTX_free(bn_mont_ctx)
|
||||
except InternalError as e:
|
||||
# We want to catch specific InternalExceptions.
|
||||
# https://github.com/openssl/openssl/blob/master/include/openssl/ecerr.h
|
||||
# There is also EC_R_POINT_IS_NOT_ON_CURVE (code 107),
|
||||
# but somehow it is never triggered during deserialization.
|
||||
if e.err_code[0].reason == 110: # EC_R_INVALID_COMPRESSED_POINT
|
||||
raise ErrorInvalidCompressedPoint
|
||||
elif e.err_code[0].reason == 102: # EC_R_INVALID_ENCODING
|
||||
raise ErrorInvalidPointEncoding
|
||||
else:
|
||||
# Any other exception, we raise it.
|
||||
# (although at the moment I'm not sure what should one do to cause it)
|
||||
raise e # pragma: no cover
|
||||
return point
|
||||
|
||||
|
||||
def point_to_bytes_compressed(curve: Curve, point):
|
||||
point_conversion_form = backend._lib.POINT_CONVERSION_COMPRESSED
|
||||
|
||||
size = curve.field_element_size + 1 # compressed point size
|
||||
|
||||
bin_ptr = backend._ffi.new("unsigned char[]", size)
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
bin_len = backend._lib.EC_POINT_point2oct(
|
||||
curve.ec_group, point, point_conversion_form,
|
||||
bin_ptr, size, bn_ctx
|
||||
)
|
||||
backend.openssl_assert(bin_len != 0)
|
||||
|
||||
return bytes(backend._ffi.buffer(bin_ptr, bin_len)[:])
|
||||
|
||||
|
||||
def point_eq(curve: Curve, point1, point2):
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
is_equal = backend._lib.EC_POINT_cmp(curve.ec_group, point1, point2, bn_ctx)
|
||||
backend.openssl_assert(is_equal != -1)
|
||||
|
||||
# 1 is not-equal, 0 is equal, -1 is error
|
||||
return is_equal == 0
|
||||
|
||||
|
||||
def point_mul_bn(curve: Curve, point, bn):
|
||||
prod = _point_new(curve.ec_group)
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_POINT_mul(curve.ec_group, prod, backend._ffi.NULL, point, bn, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
return prod
|
||||
|
||||
|
||||
def point_add(curve: Curve, point1, point2):
|
||||
op_sum = _point_new(curve.ec_group)
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_POINT_add(curve.ec_group, op_sum, point1, point2, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
return op_sum
|
||||
|
||||
|
||||
def point_neg(curve: Curve, point):
|
||||
inv = backend._lib.EC_POINT_dup(point, curve.ec_group)
|
||||
backend.openssl_assert(inv != backend._ffi.NULL)
|
||||
inv = backend._ffi.gc(inv, backend._lib.EC_POINT_clear_free)
|
||||
|
||||
with backend._tmp_bn_ctx() as bn_ctx:
|
||||
res = backend._lib.EC_POINT_invert(curve.ec_group, inv, bn_ctx)
|
||||
backend.openssl_assert(res == 1)
|
||||
|
||||
return inv
|
||||
|
||||
|
||||
def point_to_pubkey(curve: Curve, point):
|
||||
|
||||
ec_key = backend._lib.EC_KEY_new()
|
||||
backend.openssl_assert(ec_key != backend._ffi.NULL)
|
||||
ec_key = backend._ffi.gc(ec_key, backend._lib.EC_KEY_free)
|
||||
|
||||
set_group_result = backend._lib.EC_KEY_set_group(ec_key, curve.ec_group)
|
||||
backend.openssl_assert(set_group_result == 1)
|
||||
|
||||
set_pubkey_result = backend._lib.EC_KEY_set_public_key(ec_key, point)
|
||||
backend.openssl_assert(set_pubkey_result == 1)
|
||||
|
||||
evp_pkey = backend._ec_cdata_to_evp_pkey(ec_key)
|
||||
return _EllipticCurvePublicKey(backend, ec_key, evp_pkey)
|
||||
|
|
Loading…
Reference in New Issue