mirror of https://github.com/nucypher/nucypher.git
Add "not" operator and evaluation to CompoundAccessCondition.
parent
caa11ab5d5
commit
d91eacefe3
|
@ -13,12 +13,14 @@ from marshmallow import (
|
|||
pre_load,
|
||||
validate,
|
||||
validates,
|
||||
validates_schema,
|
||||
)
|
||||
from packaging.version import parse as parse_version
|
||||
|
||||
from nucypher.policy.conditions.base import AccessControlCondition, _Serializable
|
||||
from nucypher.policy.conditions.context import is_context_variable
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidCondition,
|
||||
InvalidConditionLingo,
|
||||
InvalidLogicalOperator,
|
||||
ReturnValueEvaluationError,
|
||||
|
@ -76,7 +78,9 @@ class ConditionType(Enum):
|
|||
class CompoundAccessControlCondition(AccessControlCondition):
|
||||
AND_OPERATOR = "and"
|
||||
OR_OPERATOR = "or"
|
||||
OPERATORS = (AND_OPERATOR, OR_OPERATOR)
|
||||
NOT_OPERATOR = "not"
|
||||
|
||||
OPERATORS = (AND_OPERATOR, OR_OPERATOR, NOT_OPERATOR)
|
||||
CONDITION_TYPE = ConditionType.COMPOUND.value
|
||||
|
||||
class Schema(CamelCaseSchema):
|
||||
|
@ -85,15 +89,30 @@ class CompoundAccessControlCondition(AccessControlCondition):
|
|||
validate=validate.Equal(ConditionType.COMPOUND.value), required=True
|
||||
)
|
||||
name = fields.Str(required=False)
|
||||
operator = fields.Str(required=True, validate=validate.OneOf(["and", "or"]))
|
||||
operands = fields.List(
|
||||
_ConditionField, required=True, validate=validate.Length(min=2)
|
||||
)
|
||||
operator = fields.Str(required=True)
|
||||
operands = fields.List(_ConditionField, required=True)
|
||||
|
||||
# maintain field declaration ordering
|
||||
class Meta:
|
||||
ordered = True
|
||||
|
||||
@validates_schema
|
||||
def validate_operator_and_operands(self, data, **kwargs):
|
||||
operator = data["operator"]
|
||||
if operator not in CompoundAccessControlCondition.OPERATORS:
|
||||
raise InvalidLogicalOperator(f"{operator} is not a valid operator")
|
||||
|
||||
operands = data["operands"]
|
||||
if operator == CompoundAccessControlCondition.NOT_OPERATOR:
|
||||
if len(operands) != 1:
|
||||
raise InvalidConditionLingo(
|
||||
f"Only 1 operand permitted for '{operator}' condition"
|
||||
)
|
||||
elif len(operands) < 2:
|
||||
raise InvalidConditionLingo(
|
||||
f"Minimum of 2 operand needed for '{operator}' compound condition"
|
||||
)
|
||||
|
||||
@post_load
|
||||
def make(self, data, **kwargs):
|
||||
return CompoundAccessControlCondition(**data)
|
||||
|
@ -111,9 +130,25 @@ class CompoundAccessControlCondition(AccessControlCondition):
|
|||
"operands": [CONDITION*]
|
||||
}
|
||||
"""
|
||||
self.condition_type = condition_type
|
||||
if condition_type != self.CONDITION_TYPE:
|
||||
raise InvalidCondition(
|
||||
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
|
||||
)
|
||||
|
||||
if operator not in self.OPERATORS:
|
||||
raise InvalidLogicalOperator(f"{operator} is not a valid operator")
|
||||
raise InvalidCondition(f"{operator} is not a valid operator")
|
||||
|
||||
if operator == self.NOT_OPERATOR:
|
||||
if len(operands) != 1:
|
||||
raise InvalidCondition(
|
||||
f"Only 1 operand permitted for '{operator}' condition"
|
||||
)
|
||||
elif len(operands) < 2:
|
||||
raise InvalidCondition(
|
||||
f"Minimum of 2 operand needed for '{operator}' compound condition"
|
||||
)
|
||||
|
||||
self.condition_type = condition_type
|
||||
self.operator = operator
|
||||
self.operands = operands
|
||||
self.condition_type = condition_type
|
||||
|
@ -134,12 +169,17 @@ class CompoundAccessControlCondition(AccessControlCondition):
|
|||
# short-circuit check
|
||||
if overall_result is False:
|
||||
break
|
||||
else:
|
||||
elif self.operator == self.OR_OPERATOR:
|
||||
# or operator
|
||||
overall_result = overall_result or current_result
|
||||
# short-circuit check
|
||||
if overall_result is True:
|
||||
break
|
||||
elif self.operator == self.NOT_OPERATOR:
|
||||
return not current_result, current_value
|
||||
else:
|
||||
# should never get here; raise just in case
|
||||
raise ValueError(f"Invalid operator {self.operator}")
|
||||
|
||||
return overall_result, values
|
||||
|
||||
|
@ -154,6 +194,11 @@ class AndCompoundCondition(CompoundAccessControlCondition):
|
|||
super().__init__(operator=self.AND_OPERATOR, operands=operands)
|
||||
|
||||
|
||||
class NotCompoundCondition(CompoundAccessControlCondition):
|
||||
def __init__(self, operand: AccessControlCondition):
|
||||
super().__init__(operator=self.NOT_OPERATOR, operands=[operand])
|
||||
|
||||
|
||||
class ReturnValueTest:
|
||||
class InvalidExpression(ValueError):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue