diff --git a/tests/test_primitives/conftest.py b/tests/test_primitives/conftest.py index 713fb58..4de3e9f 100644 --- a/tests/test_primitives/conftest.py +++ b/tests/test_primitives/conftest.py @@ -48,7 +48,7 @@ def mock_openssl(mocker, random_ec_point1: Point, random_ec_curvebn1: CurveBN, r check_point_ctypes(ec_point, other_point) assert 'BN_CTX' in str(context) assert 'EC_GROUP' in str(group) - assert random_ec_point1.group == group + assert random_ec_point1.curve.ec_group == group assert not bool(actual_backend['EC_POINT_cmp'](group, random_ec_point1.ec_point, ec_point, context)) result = actual_backend['EC_POINT_cmp'](group, random_ec_point1.ec_point, other_point, context) assert not bool(result) diff --git a/tests/test_primitives/test_point/test_point_arithmetic.py b/tests/test_primitives/test_point/test_point_arithmetic.py index 62646ae..8c1bee2 100644 --- a/tests/test_primitives/test_point/test_point_arithmetic.py +++ b/tests/test_primitives/test_point/test_point_arithmetic.py @@ -30,8 +30,7 @@ def test_point_curve_multiplication_regression(): # Make sure we have instantiated a new, unequal point in the same curve and group assert isinstance(product_with_star_operator, Point), "Point.__mul__ did not return a point instance" assert k256_point != product_with_star_operator - assert k256_point.curve_nid == product_with_star_operator.curve_nid - assert k256_point.group == product_with_star_operator.group + assert k256_point.curve == product_with_star_operator.curve product_bytes = b'\x03\xc9\xda\xa2\x88\xe2\xa0+\xb1N\xb6\xe6\x1c\xa5(\xe6\xe0p\xf6\xf4\xa9\xfc\xb1\xfaUV\xd3\xb3\x0e4\x94\xbe\x12' product_point = Point.from_bytes(product_bytes) diff --git a/tests/test_primitives/test_point/test_point_serializers.py b/tests/test_primitives/test_point/test_point_serializers.py index fa80d43..1e67360 100644 --- a/tests/test_primitives/test_point/test_point_serializers.py +++ b/tests/test_primitives/test_point/test_point_serializers.py @@ -36,15 +36,9 @@ def test_generate_random_points(): @pytest.mark.parametrize("curve, nid, point_bytes", generate_test_points_bytes()) def test_bytes_serializers(point_bytes, nid, curve): - - point_with_nid = Point.from_bytes(point_bytes, curve=nid) # from nid - assert isinstance(point_with_nid, Point) - - point_with_curve = Point.from_bytes(point_bytes, curve=curve) # from curve + point_with_curve = Point.from_bytes(point_bytes, curve=curve) # from curve assert isinstance(point_with_curve, Point) - assert point_with_nid == point_with_curve - the_same_point_bytes = point_with_curve.to_bytes() assert point_bytes == the_same_point_bytes @@ -67,9 +61,7 @@ def test_bytes_serializers(point_bytes, nid, curve): @pytest.mark.parametrize("curve, nid, point_affine", generate_test_points_affine()) def test_affine(point_affine, nid, curve): - point = Point.from_affine(point_affine, curve=nid) # from nid - the_same_point = Point.from_affine(point_affine, curve=curve) # from curve instance - assert point == the_same_point + point = Point.from_affine(point_affine, curve=curve) # from curve assert isinstance(point, Point) point_affine2 = point.to_affine() assert point_affine == point_affine2 diff --git a/tests/test_simple_api.py b/tests/test_simple_api.py index c6ff3e0..348044e 100644 --- a/tests/test_simple_api.py +++ b/tests/test_simple_api.py @@ -4,22 +4,23 @@ from cryptography.hazmat.primitives.asymmetric import ec from umbral import pre from umbral.fragments import KFrag, CapsuleFrag +from umbral.curve import SECP384R1, SECP256R1 from umbral.config import default_curve from umbral.params import UmbralParameters from umbral.signing import Signer from umbral.keys import UmbralPrivateKey, UmbralPublicKey from .conftest import parameters, wrong_parameters + secp_curves = [ - ec.SECP384R1, - ec.SECP192R1 + SECP384R1(), + SECP256R1() ] @pytest.mark.parametrize("N, M", parameters) def test_simple_api(N, M, curve=default_curve()): """Manually injects umbralparameters for multi-curve testing.""" - params = UmbralParameters(curve=curve) delegating_privkey = UmbralPrivateKey.gen_key(params=params) diff --git a/umbral/curvebn.py b/umbral/curvebn.py index 8d3c6a0..7fcec25 100644 --- a/umbral/curvebn.py +++ b/umbral/curvebn.py @@ -241,7 +241,7 @@ class CurveBN(object): """ if type(other) == int: other = openssl._int_to_bn(other) - other = CurveBN(other, None, None, None) + other = CurveBN(other, self.curve) rem = openssl._get_new_BN() with backend._tmp_bn_ctx() as bn_ctx: diff --git a/umbral/point.py b/umbral/point.py index bd1faef..293d48c 100644 --- a/umbral/point.py +++ b/umbral/point.py @@ -67,9 +67,9 @@ class Point(object): group = openssl._get_ec_group_by_curve_nid(curve.curve_nid) ec_point = openssl._get_EC_POINT_via_affine(affine_x, affine_y, - ec_group=self.curve.ec_group) + ec_group=curve.ec_group) - return cls(ec_point, curve_nid, group) + return cls(ec_point, curve) def to_affine(self): """