"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from dataclasses import dataclass
from http import HTTPStatus
from typing import List, Optional, Tuple, Type
from unittest.mock import Mock, patch
import pytest
from marshmallow import fields
from web3.providers import BaseProvider
from nucypher.policy.conditions.exceptions import (
ConditionEvaluationFailed,
ContextVariableVerificationFailed,
InvalidCondition,
InvalidConditionLingo,
InvalidContextVariableData,
NoConnectionToChain,
RequiredContextVariable,
ReturnValueEvaluationError,
)
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.conditions.utils import (
CamelCaseSchema,
ConditionEvalError,
ConditionProviderManager,
camel_case_to_snake,
check_and_convert_big_int_string_to_int,
evaluate_condition_lingo,
to_camelcase,
)
from tests.constants import INT256_MIN, UINT256_MAX
# Table of failure scenarios used to parametrize
# test_evaluate_condition_exception_cases: each entry pairs an exception raised
# during condition evaluation with the HTTP status code that the resulting
# ConditionEvalError is expected to carry.
FAILURE_CASE_EXCEPTION_CODE_MATCHING = [
    # (exception, constructor parameters, expected status code)
    # constructor parameters of None -> the exception is instantiated without arguments
    (ReturnValueEvaluationError, None, HTTPStatus.BAD_REQUEST),
    (InvalidConditionLingo, None, HTTPStatus.BAD_REQUEST),
    (InvalidCondition, None, HTTPStatus.BAD_REQUEST),
    (RequiredContextVariable, None, HTTPStatus.BAD_REQUEST),
    (InvalidContextVariableData, None, HTTPStatus.BAD_REQUEST),
    (ContextVariableVerificationFailed, None, HTTPStatus.FORBIDDEN),
    # built with one positional argument (presumably a chain ID - confirm against the exception class)
    (NoConnectionToChain, [1], HTTPStatus.NOT_IMPLEMENTED),
    (ConditionEvaluationFailed, None, HTTPStatus.BAD_REQUEST),
    # a generic, otherwise-unclassified exception is expected to yield a server error
    (Exception, None, HTTPStatus.INTERNAL_SERVER_ERROR),
]
@pytest.mark.parametrize("failure_case", FAILURE_CASE_EXCEPTION_CODE_MATCHING)
def test_evaluate_condition_exception_cases(
    failure_case: Tuple[Type[Exception], Optional[List], int]
):
    """Check the status code reported for each exception raised during evaluation.

    The condition's ``eval`` is mocked to raise the parametrized exception;
    ``evaluate_condition_lingo`` must then raise a ``ConditionEvalError`` whose
    ``status_code`` equals the code paired with that exception.
    """
    exc_type, ctor_args, expected_status = failure_case
    # a missing (None) argument list means "construct with no arguments"
    raised = exc_type(*ctor_args) if ctor_args else exc_type()

    lingo = Mock()
    lingo.eval.side_effect = raised

    from_dict_target = "nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
    with patch(from_dict_target, return_value=lingo):
        with pytest.raises(ConditionEvalError) as eval_error:
            # only the lingo is supplied; providers and context keep their defaults
            evaluate_condition_lingo(condition_lingo=lingo)

    assert eval_error.value.status_code == expected_status
def test_evaluate_condition_invalid_lingo():
    """Check that a lingo payload with unrecognized grammar is a bad request.

    The raised ``ConditionEvalError`` must mention the invalid grammar in its
    message and carry ``HTTPStatus.BAD_REQUEST``.
    """
    malformed_lingo = {
        "version": ConditionLingo.VERSION,
        "condition": {"dont_mind_me": "nothing_to_see_here"},
    }

    with pytest.raises(ConditionEvalError) as eval_error:
        # providers and context are not supplied, so their defaults apply
        evaluate_condition_lingo(condition_lingo=malformed_lingo)

    error = eval_error.value
    assert "Invalid condition grammar" in error.message
    assert error.status_code == HTTPStatus.BAD_REQUEST
def test_evaluate_condition_eval_returns_false():
    """Check that a condition evaluating to ``False`` is reported as forbidden.

    ``ConditionLingo.from_dict`` is patched to return a mock whose ``eval``
    returns ``False``; ``evaluate_condition_lingo`` must then raise a
    ``ConditionEvalError`` with ``HTTPStatus.FORBIDDEN``.
    """
    condition_lingo = Mock()
    condition_lingo.eval.return_value = False

    with patch(
        "nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
    ) as mocked_from_dict:
        mocked_from_dict.return_value = condition_lingo
        with pytest.raises(ConditionEvalError) as eval_error:
            evaluate_condition_lingo(
                condition_lingo=condition_lingo,
                # providers are keyed by chain ID and given as a list per chain,
                # matching how ConditionProviderManager is built elsewhere in this file
                providers=ConditionProviderManager(
                    {1: [Mock(spec=BaseProvider)]}
                ),  # fake provider
                context={"key": "value"},  # fake context
            )

    assert eval_error.value.status_code == HTTPStatus.FORBIDDEN
def test_evaluate_condition_eval_returns_true():
    """Check that a condition evaluating to ``True`` raises no error.

    ``ConditionLingo.from_dict`` is patched to return a mock whose ``eval``
    returns ``True``. The test passes if ``evaluate_condition_lingo`` completes
    without raising, given several providers and several context entries.
    """
    condition_lingo = Mock()
    condition_lingo.eval.return_value = True

    with patch(
        "nucypher.policy.conditions.lingo.ConditionLingo.from_dict"
    ) as mocked_from_dict:
        mocked_from_dict.return_value = condition_lingo
        evaluate_condition_lingo(
            condition_lingo=condition_lingo,
            # providers are keyed by chain ID and given as a list per chain,
            # matching how ConditionProviderManager is built elsewhere in this file
            providers=ConditionProviderManager(
                {
                    1: [Mock(spec=BaseProvider)],
                    2: [Mock(spec=BaseProvider)],
                }
            ),
            context={
                "key1": "value1",
                "key2": "value2",
            },  # multiple values in fake context
        )
@pytest.mark.parametrize(
    "test_case",
    (
        ("nounderscores", "nounderscores"),
        ("one_underscore", "oneUnderscore"),
        ("two_under_scores", "twoUnderScores"),
    ),
)
def test_to_from_camel_case(test_case: Tuple[str, str]):
    """Check conversion in both directions between snake_case and camelCase.

    Each case is a ``(snake_case, camelCase)`` pair; ``to_camelcase`` must map
    the first to the second and ``camel_case_to_snake`` must map it back.
    """
    snake_form, camel_form = test_case

    # snake_case -> camelCase
    assert to_camelcase(snake_form) == camel_form

    # camelCase -> snake_case
    assert camel_case_to_snake(camel_form) == snake_form
def test_camel_case_schema():
    """Check that ``CamelCaseSchema`` dumps camelCase keys and loads them back.

    A schema field declared in snake_case must appear under its camelCase name
    when dumped, and loading that dumped payload must produce the snake_case
    key again.
    """

    @dataclass
    class Function:
        field_name_with_underscores: str

    class FunctionSchema(CamelCaseSchema):
        field_name_with_underscores = fields.Str()

    value = "field_name_value"
    schema = FunctionSchema()

    # serialization: the attribute is exposed under a camelCase key
    dumped = schema.dump(Function(field_name_with_underscores=value))
    assert dumped == {"fieldNameWithUnderscores": value}

    # deserialization: the camelCase key maps back to the snake_case field
    loaded = schema.load(dumped)
    assert loaded == {"field_name_with_underscores": value}
def test_condition_provider_manager(mocker):
    """Exercise ``ConditionProviderManager.web3_endpoints`` across scenarios.

    Covers: a chain ID with no configured providers, a provider reporting the
    wrong chain ID, a single valid provider, multiple providers, and specific
    web3 instances returned in order.

    :param mocker: mock factory fixture used to build provider and web3 doubles
    """
    # no connection to chain
    manager = ConditionProviderManager(
        providers={2: [mocker.Mock(spec=BaseProvider)]}
    )
    # keep only the call under test inside pytest.raises so that an unexpected
    # error while constructing the manager cannot satisfy the expectation
    with pytest.raises(NoConnectionToChain, match="No connection to chain ID"):
        _ = list(manager.web3_endpoints(chain_id=1))

    # invalid provider chain
    manager = ConditionProviderManager(providers={2: [mocker.Mock(spec=BaseProvider)]})
    w3 = mocker.Mock()
    w3.eth.chain_id = (
        1  # make w3 instance created from provider have incorrect chain id
    )
    with patch.object(manager, "_configure_w3", return_value=w3):
        with pytest.raises(
            NoConnectionToChain, match="Problematic provider endpoints for chain ID"
        ):
            _ = list(manager.web3_endpoints(chain_id=2))

    # valid provider chain
    manager = ConditionProviderManager(providers={2: [mocker.Mock(spec=BaseProvider)]})
    with patch.object(manager, "_check_chain_id", return_value=None):
        assert len(list(manager.web3_endpoints(chain_id=2))) == 1

    # multiple providers
    manager = ConditionProviderManager(
        providers={2: [mocker.Mock(spec=BaseProvider), mocker.Mock(spec=BaseProvider)]}
    )
    with patch.object(manager, "_check_chain_id", return_value=None):
        w3_instances = list(manager.web3_endpoints(chain_id=2))
        assert len(w3_instances) == 2
        for w3_instance in w3_instances:
            assert w3_instance  # actual object returned
            assert w3_instance.middleware_onion.get("poa")  # poa middleware injected

    # specific w3 instances (reuses the two-provider manager from above)
    w3_1 = mocker.Mock()
    w3_1.eth.chain_id = 2
    w3_2 = mocker.Mock()
    w3_2.eth.chain_id = 2
    with patch.object(manager, "_configure_w3", side_effect=[w3_1, w3_2]):
        assert list(manager.web3_endpoints(chain_id=2)) == [w3_1, w3_2]
@pytest.mark.parametrize(
    "value, expectedValue",
    [
        # number string
        ("123132312", None),
        ("-1231231", None),
        # big int string of form "n"
        (f"{UINT256_MAX}n", UINT256_MAX),
        (f"{INT256_MIN}n", INT256_MIN),
        (f"{UINT256_MAX*2}n", UINT256_MAX * 2),  # larger than uint256 max
        (f"{INT256_MIN*2}n", INT256_MIN * 2),  # smaller than int256 min
        ("9007199254740992n", 9007199254740992),  # bigger than max safe
        ("-9007199254740992n", -9007199254740992),  # smaller than min safe
        # regular strings
        ("Totally a number", None),
        ("Totally a number that ends with n", None),
        ("0xdeadbeef", None),
        ("fallen", None),
    ],
)
def test_conversion_from_big_int_string(value, expectedValue):
    """Check ``check_and_convert_big_int_string_to_int`` on assorted strings.

    :param value: the input string handed to the conversion helper
    :param expectedValue: the integer the helper should return, or ``None``
        when the input is expected to come back unchanged
    """
    result = check_and_convert_big_int_string_to_int(value)
    # compare against None explicitly: a legitimate expected integer of 0 is
    # falsy and would otherwise be mistaken for the "unchanged" sentinel
    if expectedValue is not None:
        assert result == expectedValue
    else:
        # value unchanged
        assert result == value