Add "not" operator and evaluation to CompoundAccessCondition.

pull/3293/head
derekpierre 2023-10-18 19:44:59 -04:00
parent caa11ab5d5
commit d91eacefe3
1 changed files with 53 additions and 8 deletions

View File

@ -13,12 +13,14 @@ from marshmallow import (
pre_load,
validate,
validates,
validates_schema,
)
from packaging.version import parse as parse_version
from nucypher.policy.conditions.base import AccessControlCondition, _Serializable
from nucypher.policy.conditions.context import is_context_variable
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
InvalidConditionLingo,
InvalidLogicalOperator,
ReturnValueEvaluationError,
@ -76,7 +78,9 @@ class ConditionType(Enum):
class CompoundAccessControlCondition(AccessControlCondition):
AND_OPERATOR = "and"
OR_OPERATOR = "or"
OPERATORS = (AND_OPERATOR, OR_OPERATOR)
NOT_OPERATOR = "not"
OPERATORS = (AND_OPERATOR, OR_OPERATOR, NOT_OPERATOR)
CONDITION_TYPE = ConditionType.COMPOUND.value
class Schema(CamelCaseSchema):
@ -85,15 +89,30 @@ class CompoundAccessControlCondition(AccessControlCondition):
validate=validate.Equal(ConditionType.COMPOUND.value), required=True
)
name = fields.Str(required=False)
operator = fields.Str(required=True, validate=validate.OneOf(["and", "or"]))
operands = fields.List(
_ConditionField, required=True, validate=validate.Length(min=2)
)
operator = fields.Str(required=True)
operands = fields.List(_ConditionField, required=True)
# maintain field declaration ordering
class Meta:
ordered = True
@validates_schema
def validate_operator_and_operands(self, data, **kwargs):
operator = data["operator"]
if operator not in CompoundAccessControlCondition.OPERATORS:
raise InvalidLogicalOperator(f"{operator} is not a valid operator")
operands = data["operands"]
if operator == CompoundAccessControlCondition.NOT_OPERATOR:
if len(operands) != 1:
raise InvalidConditionLingo(
f"Only 1 operand permitted for '{operator}' condition"
)
elif len(operands) < 2:
raise InvalidConditionLingo(
f"Minimum of 2 operand needed for '{operator}' compound condition"
)
@post_load
def make(self, data, **kwargs):
return CompoundAccessControlCondition(**data)
@ -111,9 +130,25 @@ class CompoundAccessControlCondition(AccessControlCondition):
"operands": [CONDITION*]
}
"""
self.condition_type = condition_type
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
if operator not in self.OPERATORS:
raise InvalidLogicalOperator(f"{operator} is not a valid operator")
raise InvalidCondition(f"{operator} is not a valid operator")
if operator == self.NOT_OPERATOR:
if len(operands) != 1:
raise InvalidCondition(
f"Only 1 operand permitted for '{operator}' condition"
)
elif len(operands) < 2:
raise InvalidCondition(
f"Minimum of 2 operand needed for '{operator}' compound condition"
)
self.condition_type = condition_type
self.operator = operator
self.operands = operands
self.condition_type = condition_type
@ -134,12 +169,17 @@ class CompoundAccessControlCondition(AccessControlCondition):
# short-circuit check
if overall_result is False:
break
else:
elif self.operator == self.OR_OPERATOR:
# or operator
overall_result = overall_result or current_result
# short-circuit check
if overall_result is True:
break
elif self.operator == self.NOT_OPERATOR:
return not current_result, current_value
else:
# should never get here; raise just in case
raise ValueError(f"Invalid operator {self.operator}")
return overall_result, values
@ -154,6 +194,11 @@ class AndCompoundCondition(CompoundAccessControlCondition):
super().__init__(operator=self.AND_OPERATOR, operands=operands)
class NotCompoundCondition(CompoundAccessControlCondition):
def __init__(self, operand: AccessControlCondition):
super().__init__(operator=self.NOT_OPERATOR, operands=[operand])
class ReturnValueTest:
class InvalidExpression(ValueError):
pass