From d91eacefe3593d37a6435fb56f6f221edee47f4b Mon Sep 17 00:00:00 2001 From: derekpierre Date: Wed, 18 Oct 2023 19:44:59 -0400 Subject: [PATCH] Add "not" operator and evaluation to CompoundAccessCondition. --- nucypher/policy/conditions/lingo.py | 61 +++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/nucypher/policy/conditions/lingo.py b/nucypher/policy/conditions/lingo.py index ee5580be8..785d07a4f 100644 --- a/nucypher/policy/conditions/lingo.py +++ b/nucypher/policy/conditions/lingo.py @@ -13,12 +13,14 @@ from marshmallow import ( pre_load, validate, validates, + validates_schema, ) from packaging.version import parse as parse_version from nucypher.policy.conditions.base import AccessControlCondition, _Serializable from nucypher.policy.conditions.context import is_context_variable from nucypher.policy.conditions.exceptions import ( + InvalidCondition, InvalidConditionLingo, InvalidLogicalOperator, ReturnValueEvaluationError, @@ -76,7 +78,9 @@ class ConditionType(Enum): class CompoundAccessControlCondition(AccessControlCondition): AND_OPERATOR = "and" OR_OPERATOR = "or" - OPERATORS = (AND_OPERATOR, OR_OPERATOR) + NOT_OPERATOR = "not" + + OPERATORS = (AND_OPERATOR, OR_OPERATOR, NOT_OPERATOR) CONDITION_TYPE = ConditionType.COMPOUND.value class Schema(CamelCaseSchema): @@ -85,15 +89,30 @@ class CompoundAccessControlCondition(AccessControlCondition): validate=validate.Equal(ConditionType.COMPOUND.value), required=True ) name = fields.Str(required=False) - operator = fields.Str(required=True, validate=validate.OneOf(["and", "or"])) - operands = fields.List( - _ConditionField, required=True, validate=validate.Length(min=2) - ) + operator = fields.Str(required=True) + operands = fields.List(_ConditionField, required=True) # maintain field declaration ordering class Meta: ordered = True + @validates_schema + def validate_operator_and_operands(self, data, **kwargs): + operator = data["operator"] + if operator not in CompoundAccessControlCondition.OPERATORS: + raise InvalidLogicalOperator(f"{operator} is not a valid operator") + + operands = data["operands"] + if operator == CompoundAccessControlCondition.NOT_OPERATOR: + if len(operands) != 1: + raise InvalidConditionLingo( + f"Only 1 operand permitted for '{operator}' condition" + ) + elif len(operands) < 2: + raise InvalidConditionLingo( + f"Minimum of 2 operand needed for '{operator}' compound condition" + ) + @post_load def make(self, data, **kwargs): return CompoundAccessControlCondition(**data) @@ -111,9 +130,25 @@ class CompoundAccessControlCondition(AccessControlCondition): "operands": [CONDITION*] } """ - self.condition_type = condition_type + if condition_type != self.CONDITION_TYPE: + raise InvalidCondition( + f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type." + ) + if operator not in self.OPERATORS: - raise InvalidLogicalOperator(f"{operator} is not a valid operator") + raise InvalidCondition(f"{operator} is not a valid operator") + + if operator == self.NOT_OPERATOR: + if len(operands) != 1: + raise InvalidCondition( + f"Only 1 operand permitted for '{operator}' condition" + ) + elif len(operands) < 2: + raise InvalidCondition( + f"Minimum of 2 operand needed for '{operator}' compound condition" + ) + + self.condition_type = condition_type self.operator = operator self.operands = operands self.condition_type = condition_type @@ -134,12 +169,17 @@ class CompoundAccessControlCondition(AccessControlCondition): # short-circuit check if overall_result is False: break - else: + elif self.operator == self.OR_OPERATOR: # or operator overall_result = overall_result or current_result # short-circuit check if overall_result is True: break + elif self.operator == self.NOT_OPERATOR: + return not current_result, current_value + else: + # should never get here; raise just in case + raise ValueError(f"Invalid operator {self.operator}") return overall_result, values @@ -154,6 +194,11 @@ class AndCompoundCondition(CompoundAccessControlCondition): super().__init__(operator=self.AND_OPERATOR, operands=operands) +class NotCompoundCondition(CompoundAccessControlCondition): + def __init__(self, operand: AccessControlCondition): + super().__init__(operator=self.NOT_OPERATOR, operands=[operand]) + + class ReturnValueTest: class InvalidExpression(ValueError): pass