diff --git a/nucypher/characters/control/interfaces.py b/nucypher/characters/control/interfaces.py index df299bc4a..cab9fb728 100644 --- a/nucypher/characters/control/interfaces.py +++ b/nucypher/characters/control/interfaces.py @@ -25,7 +25,9 @@ def character_control_interface(func): # Get specification interface_name = func.__name__ - input_specification, output_specification = instance.get_specifications(interface_name=interface_name) + input_specification, optional_specification, output_specification =\ + instance.get_specifications(interface_name=interface_name) + # XXX if request and instance.serialize: @@ -234,4 +236,3 @@ class EnricoInterface(CharacterPublicInterface, EnricoSpecification): message_kit, signature = self.enrico.encrypt_message(bytes(message, encoding='utf-8')) response_data = {'message_kit': message_kit, 'signature': signature} return response_data - diff --git a/nucypher/characters/control/specifications.py b/nucypher/characters/control/specifications.py index fc7f95fd3..3fa30ae62 100644 --- a/nucypher/characters/control/specifications.py +++ b/nucypher/characters/control/specifications.py @@ -22,11 +22,13 @@ class CharacterSpecification(ABC): if cls._specifications is NotImplemented: raise NotImplementedError("Missing specifications for character") try: - input_specification, output_specification = cls.specifications()[interface_name] + spec = cls.specifications()[interface_name] except KeyError: raise cls.SpecificationError(f"{cls.__class__.__name__} has no such control interface: '{interface_name}'") - return input_specification, output_specification + return spec.get('in', tuple()),\ + spec.get('optional', tuple()),\ + spec.get('out', tuple()) @classmethod def specifications(cls): @@ -36,8 +38,9 @@ class CharacterSpecification(ABC): return cls._specifications @staticmethod - def __validate(specification: tuple, data: dict, error_class): - invalid_fields = set(data.keys()) - set(specification) + def __validate(specification: tuple, data: dict, error_class, + optional_specification: tuple = ()): + invalid_fields = set(data.keys()) - set(specification) - set(optional_specification) if invalid_fields: pretty_invalid_fields = ', '.join(invalid_fields) raise error_class(f"Got: {pretty_invalid_fields}") @@ -50,35 +53,37 @@ class CharacterSpecification(ABC): return True def validate_request(self, interface_name: str, request: dict) -> bool: - input_specification, _ = self.get_specifications(interface_name=interface_name) - return self.__validate(specification=input_specification, data=request, error_class=self.InvalidInputField) + input_specification, optional_specification, _ = self.get_specifications(interface_name=interface_name) + return self.__validate(specification=input_specification, + optional_specification=optional_specification, + data=request, error_class=self.InvalidInputField) def validate_response(self, interface_name: str, response: dict) -> bool: - _, output_specification = self.get_specifications(interface_name=interface_name) + _, _, output_specification = self.get_specifications(interface_name=interface_name) return self.__validate(specification=output_specification, data=response, error_class=self.InvalidInputField) class AliceSpecification(CharacterSpecification): - __create_policy = (('bob_encrypting_key', 'bob_verifying_key', 'm', 'n', 'label', 'expiration'), # In - ('label', 'policy_encrypting_key')) # Out + __create_policy = {'in': ('bob_encrypting_key', 'bob_verifying_key', 'm', 'n', 'label', 'expiration'), + 'optional': ('value'), + 'out': ('label', 'policy_encrypting_key')} - __derive_policy_encrypting_key = (('label', ), # In - ('policy_encrypting_key', 'label')) # Out + __derive_policy_encrypting_key = {'in': ('label', ), + 'out': ('policy_encrypting_key', 'label')} - __grant = (('bob_encrypting_key', 'bob_verifying_key', 'm', 'n', 'label', 'expiration'), # In - ('treasure_map', 'policy_encrypting_key', 'alice_verifying_key')) # Out + __grant = {'in': ('bob_encrypting_key', 'bob_verifying_key', 'm', 'n', 'label', 'expiration'), + 'optional': ('value'), + 'out': ('treasure_map', 'policy_encrypting_key', 'alice_verifying_key')} - __revoke = (('label', 'bob_verifying_key', ), # In - ('failed_revocations',)) # Out + __revoke = {'in': ('label', 'bob_verifying_key', ), + 'out': ('failed_revocations',)} - __decrypt = ( - ('label', 'message_kit'), # In - ('cleartexts', ), # Out - ) + __decrypt = {'in': ('label', 'message_kit'), + 'out': ('cleartexts', )} - __public_keys = ((), - ('alice_verifying_key',)) + __public_keys = {'in': (), + 'out': ('alice_verifying_key',)} _specifications = {'create_policy': __create_policy, # type: Tuple[Tuple[str]] 'derive_policy_encrypting_key': __derive_policy_encrypting_key, @@ -90,14 +95,14 @@ class AliceSpecification(CharacterSpecification): class BobSpecification(CharacterSpecification): - __join_policy = (('label', 'alice_verifying_key'), - ('policy_encrypting_key', )) + __join_policy = {'in': ('label', 'alice_verifying_key'), + 'out': ('policy_encrypting_key', )} - __retrieve = (('label', 'policy_encrypting_key', 'alice_verifying_key', 'message_kit'), - ('cleartexts', )) + __retrieve = {'in': ('label', 'policy_encrypting_key', 'alice_verifying_key', 'message_kit'), + 'out': ('cleartexts', )} - __public_keys = ((), - ('bob_encrypting_key', 'bob_verifying_key')) + __public_keys = {'in': (), + 'out': ('bob_encrypting_key', 'bob_verifying_key')} _specifications = {'join_policy': __join_policy, 'retrieve': __retrieve, @@ -106,7 +111,7 @@ class BobSpecification(CharacterSpecification): class EnricoSpecification(CharacterSpecification): - __encrypt_message = (('message', ), - ('message_kit', 'signature')) + __encrypt_message = {'in': ('message', ), + 'out': ('message_kit', 'signature')} _specifications = {'encrypt_message': __encrypt_message} diff --git a/tests/characters/control/test_rpc_control.py b/tests/characters/control/test_rpc_control.py index 2e1d94992..de36a1e59 100644 --- a/tests/characters/control/test_rpc_control.py +++ b/tests/characters/control/test_rpc_control.py @@ -8,7 +8,7 @@ enrico_specification = EnricoSpecification() def validate_json_rpc_response_data(response, method_name, specification): - _input_fields, required_output_fileds = specification.get_specifications(interface_name=method_name) + _input_fields, _optional_fields, required_output_fileds = specification.get_specifications(interface_name=method_name) assert 'jsonrpc' in response.data for output_field in required_output_fileds: assert output_field in response.content @@ -22,7 +22,7 @@ def test_alice_rpc_character_control_create_policy(alice_rpc_test_client, create assert rpc_response.success is True assert rpc_response.id == 1 - _input_fields, required_output_fileds = alice_specification.get_specifications(interface_name=method_name) + _input_fields, _optional_fields, required_output_fileds = alice_specification.get_specifications(interface_name=method_name) assert 'jsonrpc' in rpc_response.data for output_field in required_output_fileds: