mirror of https://github.com/nucypher/nucypher.git
Be more eager validating return value schema.
Ensure index is non-negative integer.pull/3334/head
parent
cda1d1b26b
commit
a620a2d86f
|
@ -17,6 +17,7 @@ from marshmallow import (
|
|||
validates,
|
||||
validates_schema,
|
||||
)
|
||||
from marshmallow.validate import OneOf, Range
|
||||
from packaging.version import parse as parse_version
|
||||
|
||||
from nucypher.policy.conditions.base import AccessControlCondition, _Serializable
|
||||
|
@ -197,27 +198,29 @@ class NotCompoundCondition(CompoundAccessControlCondition):
|
|||
super().__init__(operator=self.NOT_OPERATOR, operands=[operand])
|
||||
|
||||
|
||||
_COMPARATOR_FUNCTIONS = {
|
||||
"==": pyoperator.eq,
|
||||
"!=": pyoperator.ne,
|
||||
">": pyoperator.gt,
|
||||
"<": pyoperator.lt,
|
||||
"<=": pyoperator.le,
|
||||
">=": pyoperator.ge,
|
||||
}
|
||||
|
||||
|
||||
class ReturnValueTest:
|
||||
class InvalidExpression(ValueError):
|
||||
pass
|
||||
|
||||
_COMPARATOR_FUNCTIONS = {
|
||||
"==": pyoperator.eq,
|
||||
"!=": pyoperator.ne,
|
||||
">": pyoperator.gt,
|
||||
"<": pyoperator.lt,
|
||||
"<=": pyoperator.le,
|
||||
">=": pyoperator.ge,
|
||||
}
|
||||
COMPARATORS = tuple(_COMPARATOR_FUNCTIONS)
|
||||
|
||||
class ReturnValueTestSchema(CamelCaseSchema):
|
||||
SKIP_VALUES = (None,)
|
||||
comparator = fields.Str(required=True)
|
||||
comparator = fields.Str(required=True, validate=OneOf(_COMPARATOR_FUNCTIONS))
|
||||
value = fields.Raw(
|
||||
allow_none=False, required=True
|
||||
) # any valid type (excludes None)
|
||||
index = fields.Raw(allow_none=True)
|
||||
index = fields.Int(strict=True, required=False, validate=Range(min=0))
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
|
@ -229,9 +232,9 @@ class ReturnValueTest:
|
|||
f'"{comparator}" is not a permitted comparator.'
|
||||
)
|
||||
|
||||
if index is not None and not isinstance(index, int):
|
||||
if index is not None and (not isinstance(index, int) or index < 0):
|
||||
raise self.InvalidExpression(
|
||||
f'"{index}" is not a permitted index. Must be a an integer.'
|
||||
f'"{index}" is not a permitted index. Must be a an non-negative integer.'
|
||||
)
|
||||
|
||||
if not is_context_variable(value):
|
||||
|
@ -309,7 +312,7 @@ class ReturnValueTest:
|
|||
processed_data = self._process_data(data, self.index)
|
||||
left_operand = self._sanitize_value(processed_data)
|
||||
right_operand = self._sanitize_value(self.value)
|
||||
result = self._COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
|
||||
result = _COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -26,6 +26,12 @@ def test_return_value_test_schema():
|
|||
errors = schema.validate(data=test_dict)
|
||||
assert errors, f"{errors}"
|
||||
|
||||
# invalid comparator should cause error
|
||||
test_dict = schema.dump(return_value_test)
|
||||
test_dict["comparator"] = "<>"
|
||||
errors = schema.validate(data=test_dict)
|
||||
assert errors, f"{errors}"
|
||||
|
||||
# missing value should cause error
|
||||
test_dict = schema.dump(return_value_test)
|
||||
del test_dict["value"]
|
||||
|
@ -38,10 +44,30 @@ def test_return_value_test_schema():
|
|||
errors = schema.validate(data=test_dict)
|
||||
assert not errors, f"{errors}"
|
||||
|
||||
# negative index should cause error
|
||||
test_dict = schema.dump(return_value_test)
|
||||
test_dict["index"] = -3
|
||||
errors = schema.validate(data=test_dict)
|
||||
assert errors, f"{errors}"
|
||||
|
||||
# non-integer index should cause error
|
||||
test_dict = schema.dump(return_value_test)
|
||||
test_dict["index"] = "25"
|
||||
errors = schema.validate(data=test_dict)
|
||||
assert errors, f"{errors}"
|
||||
|
||||
|
||||
def test_return_value_index_invalid():
|
||||
with pytest.raises(ReturnValueTest.InvalidExpression):
|
||||
_ = ReturnValueTest(comparator=">", value="0", index="james")
|
||||
_ = ReturnValueTest(comparator=">", value=0, index="james")
|
||||
|
||||
with pytest.raises(ReturnValueTest.InvalidExpression):
|
||||
_ = ReturnValueTest(comparator=">", value=0, index=-1)
|
||||
|
||||
with pytest.raises(ReturnValueTest.InvalidExpression):
|
||||
_ = ReturnValueTest(
|
||||
comparator=">", value=0, index="10"
|
||||
) # should not be a string
|
||||
|
||||
|
||||
def test_return_value_index():
|
||||
|
|
Loading…
Reference in New Issue