Be more eager validating return value schema.

Ensure index is non-negative integer.
pull/3334/head
derekpierre 2023-11-07 10:02:27 -05:00 committed by KPrasch
parent cda1d1b26b
commit a620a2d86f
2 changed files with 43 additions and 14 deletions

View File

@ -17,6 +17,7 @@ from marshmallow import (
validates,
validates_schema,
)
from marshmallow.validate import OneOf, Range
from packaging.version import parse as parse_version
from nucypher.policy.conditions.base import AccessControlCondition, _Serializable
@ -197,27 +198,29 @@ class NotCompoundCondition(CompoundAccessControlCondition):
super().__init__(operator=self.NOT_OPERATOR, operands=[operand])
_COMPARATOR_FUNCTIONS = {
"==": pyoperator.eq,
"!=": pyoperator.ne,
">": pyoperator.gt,
"<": pyoperator.lt,
"<=": pyoperator.le,
">=": pyoperator.ge,
}
class ReturnValueTest:
class InvalidExpression(ValueError):
pass
_COMPARATOR_FUNCTIONS = {
"==": pyoperator.eq,
"!=": pyoperator.ne,
">": pyoperator.gt,
"<": pyoperator.lt,
"<=": pyoperator.le,
">=": pyoperator.ge,
}
COMPARATORS = tuple(_COMPARATOR_FUNCTIONS)
class ReturnValueTestSchema(CamelCaseSchema):
SKIP_VALUES = (None,)
comparator = fields.Str(required=True)
comparator = fields.Str(required=True, validate=OneOf(_COMPARATOR_FUNCTIONS))
value = fields.Raw(
allow_none=False, required=True
) # any valid type (excludes None)
index = fields.Raw(allow_none=True)
index = fields.Int(strict=True, required=False, validate=Range(min=0))
@post_load
def make(self, data, **kwargs):
@ -229,9 +232,9 @@ class ReturnValueTest:
f'"{comparator}" is not a permitted comparator.'
)
if index is not None and not isinstance(index, int):
if index is not None and (not isinstance(index, int) or index < 0):
raise self.InvalidExpression(
f'"{index}" is not a permitted index. Must be a an integer.'
f'"{index}" is not a permitted index. Must be a an non-negative integer.'
)
if not is_context_variable(value):
@ -309,7 +312,7 @@ class ReturnValueTest:
processed_data = self._process_data(data, self.index)
left_operand = self._sanitize_value(processed_data)
right_operand = self._sanitize_value(self.value)
result = self._COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
result = _COMPARATOR_FUNCTIONS[self.comparator](left_operand, right_operand)
return result

View File

@ -26,6 +26,12 @@ def test_return_value_test_schema():
errors = schema.validate(data=test_dict)
assert errors, f"{errors}"
# invalid comparator should cause error
test_dict = schema.dump(return_value_test)
test_dict["comparator"] = "<>"
errors = schema.validate(data=test_dict)
assert errors, f"{errors}"
# missing value should cause error
test_dict = schema.dump(return_value_test)
del test_dict["value"]
@ -38,10 +44,30 @@ def test_return_value_test_schema():
errors = schema.validate(data=test_dict)
assert not errors, f"{errors}"
# negative index should cause error
test_dict = schema.dump(return_value_test)
test_dict["index"] = -3
errors = schema.validate(data=test_dict)
assert errors, f"{errors}"
# non-integer index should cause error
test_dict = schema.dump(return_value_test)
test_dict["index"] = "25"
errors = schema.validate(data=test_dict)
assert errors, f"{errors}"
def test_return_value_index_invalid():
with pytest.raises(ReturnValueTest.InvalidExpression):
_ = ReturnValueTest(comparator=">", value="0", index="james")
_ = ReturnValueTest(comparator=">", value=0, index="james")
with pytest.raises(ReturnValueTest.InvalidExpression):
_ = ReturnValueTest(comparator=">", value=0, index=-1)
with pytest.raises(ReturnValueTest.InvalidExpression):
_ = ReturnValueTest(
comparator=">", value=0, index="10"
) # should not be a string
def test_return_value_index():