mirror of https://github.com/nucypher/pyUmbral.git
214 lines
7.3 KiB
Python
214 lines
7.3 KiB
Python
"""
|
|
Copyright (C) 2018 NuCypher
|
|
|
|
This file is part of pyUmbral.
|
|
|
|
pyUmbral is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
pyUmbral is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with pyUmbral. If not, see <https://www.gnu.org/licenses/>.
|
|
"""
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
from cryptography.hazmat.backends.openssl import backend
|
|
|
|
from umbral import openssl
|
|
from umbral.config import default_curve
|
|
from umbral.curve import Curve
|
|
from umbral.curvebn import CurveBN
|
|
|
|
|
|
class Point(object):
    """
    Represents an OpenSSL EC_POINT except more Pythonic.

    A Point pairs an EC_POINT handle with the Curve whose EC_GROUP is used
    for every operation on it. Arithmetic operators always build a new
    EC_POINT for their result; the EC_POINTs of the operands are never
    modified in place.
    """

    def __init__(self, ec_point, curve: Curve) -> None:
        # ec_point: OpenSSL EC_POINT handle (as returned by the helpers in
        # umbral.openssl); curve: the Curve it lives on.
        self.ec_point = ec_point
        self.curve = curve

    @classmethod
    def expected_bytes_length(cls, curve: Optional[Curve] = None,
                              is_compressed: bool = True) -> int:
        """
        Returns the size (in bytes) of a serialized Point on the given curve.

        If no curve is provided, it uses the default curve.
        By default, it assumes compressed representation (is_compressed=True):
        a 1-byte prefix followed by one coordinate. The uncompressed form is
        a 1-byte prefix followed by two coordinates.
        """
        curve = curve if curve is not None else default_curve()

        coord_size = curve.field_order_size_in_bytes

        if is_compressed:
            return 1 + coord_size
        else:
            return 1 + 2 * coord_size

    @classmethod
    def gen_rand(cls, curve: Optional[Curve] = None) -> 'Point':
        """
        Returns a Point object with a cryptographically secure EC_POINT based
        on the provided curve (the default curve if None).

        The point is computed as k * G, where G is the curve generator and k
        is a random scalar obtained from CurveBN.gen_rand.
        """
        curve = curve if curve is not None else default_curve()

        rand_point = openssl._get_new_EC_POINT(curve)
        rand_bn = CurveBN.gen_rand(curve).bignum

        with backend._tmp_bn_ctx() as bn_ctx:
            # EC_POINT_mul(group, r, n, q, m, ctx) sets r = n*G + m*q.
            # Passing n = NULL and q = curve.generator yields r = rand_bn * G.
            res = backend._lib.EC_POINT_mul(
                curve.ec_group, rand_point, backend._ffi.NULL, curve.generator,
                rand_bn, bn_ctx
            )
            backend.openssl_assert(res == 1)

        return cls(rand_point, curve)

    @classmethod
    def from_affine(cls, coords: Tuple[int, int],
                    curve: Optional[Curve] = None) -> 'Point':
        """
        Returns a Point object from the given affine coordinates in a tuple in
        the format of (x, y) and a given curve (the default curve if None).

        Python int coordinates are converted to OpenSSL BIGNUMs. Any other
        value is passed through unchanged (presumably an existing BIGNUM).
        """
        curve = curve if curve is not None else default_curve()

        affine_x, affine_y = coords
        # curve=None: coordinates are field elements, not scalars, so they are
        # presumably not meant to be checked against the curve order --
        # NOTE(review): confirm against openssl._int_to_bn.
        if isinstance(affine_x, int):
            affine_x = openssl._int_to_bn(affine_x, curve=None)

        if isinstance(affine_y, int):
            affine_y = openssl._int_to_bn(affine_y, curve=None)

        ec_point = openssl._get_EC_POINT_via_affine(affine_x, affine_y, curve)
        return cls(ec_point, curve)

    def to_affine(self) -> Tuple[int, int]:
        """
        Returns a tuple of Python ints in the format of (x, y) that represents
        the point in the curve.
        """
        affine_x, affine_y = openssl._get_affine_coords_via_EC_POINT(
            self.ec_point, self.curve)
        return (backend._bn_to_int(affine_x), backend._bn_to_int(affine_y))

    @classmethod
    def from_bytes(cls, data: bytes, curve: Optional[Curve] = None) -> 'Point':
        """
        Returns a Point object decoded from the given octet-string data on the
        curve provided (the default curve if None).

        If OpenSSL rejects the data, this fails through backend.openssl_assert.
        """
        curve = curve if curve is not None else default_curve()

        point = openssl._get_new_EC_POINT(curve)
        with backend._tmp_bn_ctx() as bn_ctx:
            res = backend._lib.EC_POINT_oct2point(
                curve.ec_group, point, data, len(data), bn_ctx)
            backend.openssl_assert(res == 1)

        return cls(point, curve)

    def to_bytes(self, is_compressed: bool = True) -> bytes:
        """
        Returns the Point serialized as bytes. It will return a compressed form
        if is_compressed is set to True (the default), and the uncompressed
        form otherwise.
        """
        length = self.expected_bytes_length(self.curve, is_compressed)

        if is_compressed:
            point_conversion_form = backend._lib.POINT_CONVERSION_COMPRESSED
        else:
            point_conversion_form = backend._lib.POINT_CONVERSION_UNCOMPRESSED

        bin_ptr = backend._ffi.new("unsigned char[]", length)
        with backend._tmp_bn_ctx() as bn_ctx:
            # EC_POINT_point2oct returns the number of bytes written, 0 on error.
            bin_len = backend._lib.EC_POINT_point2oct(
                self.curve.ec_group, self.ec_point, point_conversion_form,
                bin_ptr, length, bn_ctx
            )
            backend.openssl_assert(bin_len != 0)

        # Copy out only the bytes OpenSSL actually wrote.
        return bytes(backend._ffi.buffer(bin_ptr, bin_len)[:])

    @classmethod
    def get_generator_from_curve(cls, curve: Optional[Curve] = None) -> 'Point':
        """
        Returns the generator Point from the given curve (the default curve if
        None) as a Point object.

        The returned Point wraps curve.generator itself rather than a copy.
        """
        curve = curve if curve is not None else default_curve()
        return cls(curve.generator, curve)

    def __eq__(self, other):
        """
        Compares two EC_POINTS for equality.

        Returns NotImplemented when other is not a Point, so comparing a Point
        with an unrelated object evaluates to False instead of raising
        AttributeError. Because __eq__ is defined without __hash__, Point
        instances are unhashable.
        """
        if not isinstance(other, Point):
            return NotImplemented

        with backend._tmp_bn_ctx() as bn_ctx:
            is_equal = backend._lib.EC_POINT_cmp(
                self.curve.ec_group, self.ec_point, other.ec_point, bn_ctx
            )
            backend.openssl_assert(is_equal != -1)

        # 1 is not-equal, 0 is equal, -1 is error
        return not bool(is_equal)

    def __mul__(self, other: CurveBN) -> 'Point':
        """
        Performs an EC_POINT_mul on an EC_POINT and a BIGNUM, returning the
        scalar product other * self as a new Point.
        """
        # TODO: Check that both operands use the same curve.
        prod = openssl._get_new_EC_POINT(self.curve)
        with backend._tmp_bn_ctx() as bn_ctx:
            # n = NULL, so EC_POINT_mul computes only prod = other * self.
            res = backend._lib.EC_POINT_mul(
                self.curve.ec_group, prod, backend._ffi.NULL,
                self.ec_point, other.bignum, bn_ctx
            )
            backend.openssl_assert(res == 1)

        return Point(prod, self.curve)

    # Scalar multiplication commutes, so `bn * point` reuses __mul__. Python
    # only reaches __rmul__ when the left operand's __mul__ returns
    # NotImplemented.
    __rmul__ = __mul__

    def __add__(self, other) -> 'Point':
        """
        Performs an EC_POINT_add on two EC_POINTS, returning the sum as a new
        Point.
        """
        op_sum = openssl._get_new_EC_POINT(self.curve)
        with backend._tmp_bn_ctx() as bn_ctx:
            res = backend._lib.EC_POINT_add(
                self.curve.ec_group, op_sum, self.ec_point, other.ec_point, bn_ctx
            )
            backend.openssl_assert(res == 1)
        return Point(op_sum, self.curve)

    def __sub__(self, other) -> 'Point':
        """
        Performs subtraction by adding the inverse of the `other` to the point.
        """
        return self + (-other)

    def __neg__(self) -> 'Point':
        """
        Computes the additive inverse of a Point, by performing an
        EC_POINT_invert on a copy of itself.
        """
        # Duplicate first so that self.ec_point is left untouched.
        inv = backend._lib.EC_POINT_dup(self.ec_point, self.curve.ec_group)
        backend.openssl_assert(inv != backend._ffi.NULL)
        # Tie the copy's lifetime to the Python object; clear_free wipes it.
        inv = backend._ffi.gc(inv, backend._lib.EC_POINT_clear_free)

        with backend._tmp_bn_ctx() as bn_ctx:
            res = backend._lib.EC_POINT_invert(
                self.curve.ec_group, inv, bn_ctx
            )
            backend.openssl_assert(res == 1)
        return Point(inv, self.curve)

    def __bytes__(self) -> bytes:
        """Returns the compressed serialization of the Point (see to_bytes)."""
        return self.to_bytes()
|