Formalize control interface decorators and fix bugs caught as a result. All I/O control validated.

pull/802/head
Kieran R. Prasch 2019-02-22 09:47:26 -07:00 committed by Kieran Prasch
parent dba872456c
commit c5aef86a1b
No known key found for this signature in database
GPG Key ID: 199AB839D4125A62
4 changed files with 72 additions and 22 deletions

View File

@ -5,13 +5,16 @@ class CharacterControlSpecification(ABC):
specifications = NotImplemented specifications = NotImplemented
class ProtocolError(ValueError): class SpecificationError(ValueError):
"""The protocol request is completely unusable""" """The protocol request is completely unusable"""
class MissingField(ProtocolError): class MissingField(SpecificationError):
"""The protocol request can be deserialized by is missing required fields""" """The protocol request can be deserialized by is missing required fields"""
class InvalidResponseField(ProtocolError): class InvalidInputField(SpecificationError):
"""Response data does not match the output specification"""
class InvalidOutputField(SpecificationError):
"""Response data does not match the output specification""" """Response data does not match the output specification"""
@classmethod @classmethod
@ -19,6 +22,6 @@ class CharacterControlSpecification(ABC):
try: try:
input_specification, output_specification = cls.specifications[interface_name] input_specification, output_specification = cls.specifications[interface_name]
except KeyError: except KeyError:
raise cls.ProtocolError(f"No Such Interface '{interface_name}'") raise cls.SpecificationError(f"No Such Control Interface '{interface_name}'")
return input_specification, output_specification return input_specification, output_specification

View File

@ -1,3 +1,4 @@
import functools
from typing import Tuple, Callable from typing import Tuple, Callable
import click import click
@ -12,21 +13,47 @@ from nucypher.crypto.kits import UmbralMessageKit
from nucypher.crypto.powers import DecryptingPower, SigningPower from nucypher.crypto.powers import DecryptingPower, SigningPower
def dict_interface(func) -> Callable:
"""Validate I/O specification for dictionary character control interfaces"""
@functools.wraps(func)
def wrapped(instance, request=None, *args, **kwargs) -> bytes:
# Get Specification
input_specification, output_specification = instance.get_specifications(interface_name=func.__name__)
if request:
instance.validate_input(request_data=request, input_specification=input_specification)
# Call Interface
response_dict = func(self=instance, request=request, *args, **kwargs)
# Output
response_dict = instance._build_response(response_data=response_dict)
instance.validate_output(response_data=response_dict, output_specification=output_specification)
return response_dict
return wrapped
def bytes_interface(func) -> Callable: def bytes_interface(func) -> Callable:
"""Manage protocol I/O validation and serialization""" """Manage protocol I/O validation and serialization"""
def wrapped(instance, request, *args, **kwargs) -> bytes: @functools.wraps(func)
def wrapped(instance, request=None, *args, **kwargs) -> bytes:
# Get Specification
input_specification, output_specification = instance.get_specifications(interface_name=func.__name__)
# Read # Read
interface_name = func.__name__ if request:
input_specification, output_specification = instance.get_specifications(interface_name=interface_name) request = instance.read(request_payload=request, input_specification=input_specification)
request_data = instance.read(request_payload=request, input_specification=input_specification)
# Inner Call # Inner Call
response_data = func(instance, request_data, *args, **kwargs) response = func(self=instance, request=request, *args, **kwargs)
# Write # Write
response_bytes = instance.write(response_data=response_data, output_specification=output_specification) response_bytes = instance.write(response_data=response, output_specification=output_specification)
return response_bytes return response_bytes
return wrapped return wrapped
@ -56,7 +83,7 @@ class AliceControl(CharacterControl, CharacterControlSpecification):
('label', 'policy_encrypting_key')) # Out ('label', 'policy_encrypting_key')) # Out
__derive_policy = (('label', ), # In __derive_policy = (('label', ), # In
('policy_encrypting_key',)) # Out ('policy_encrypting_key', 'label')) # Out
__grant = (('bob_encrypting_key', 'bob_verifying_key', 'm', 'n', 'label', 'expiration'), # In __grant = (('bob_encrypting_key', 'bob_verifying_key', 'm', 'n', 'label', 'expiration'), # In
('treasure_map', 'policy_encrypting_key', 'alice_signing_key', 'label')) # Out ('treasure_map', 'policy_encrypting_key', 'alice_signing_key', 'label')) # Out
@ -187,18 +214,21 @@ class BobControl(CharacterControlSpecification):
class AliceJSONControl(AliceControl, AliceCharacterControlJsonSerializer): class AliceJSONControl(AliceControl, AliceCharacterControlJsonSerializer):
@dict_interface
def create_policy(self, request): def create_policy(self, request):
federated_only = True # TODO: const for now federated_only = True # TODO: const for now
result = super().create_policy(**self.load_create_policy_input(request=request), federated_only=federated_only) result = super().create_policy(**self.load_create_policy_input(request=request), federated_only=federated_only)
response_data = self.dump_create_policy_output(response=result) response_data = self.dump_create_policy_output(response=result)
return response_data return response_data
def derive_policy(self, request, label: str): @dict_interface
def derive_policy(self, label: str, request=None):
label_bytes = label.encode() label_bytes = label.encode()
result = super().derive_policy(label=label_bytes) result = super().derive_policy(label=label_bytes)
response_data = self.dump_derive_policy_output(response=result) response_data = self.dump_derive_policy_output(response=result)
return response_data return response_data
@dict_interface
def grant(self, request): def grant(self, request):
result = super().grant(**self.parse_grant_input(request=request)) result = super().grant(**self.parse_grant_input(request=request))
response_data = self.dump_grant_output(response=result) response_data = self.dump_grant_output(response=result)
@ -222,6 +252,7 @@ class AliceJSONBytesControl(AliceJSONControl, AliceCharacterControlJsonSerialize
class BobJSONControl(BobControl, BobCharacterControlJSONSerializer): class BobJSONControl(BobControl, BobCharacterControlJSONSerializer):
@dict_interface
def join_policy(self, request): def join_policy(self, request):
""" """
Character control endpoint for joining a policy on the network. Character control endpoint for joining a policy on the network.
@ -230,6 +261,7 @@ class BobJSONControl(BobControl, BobCharacterControlJSONSerializer):
response = {'policy_encrypting_key': 'OK'} # FIXME response = {'policy_encrypting_key': 'OK'} # FIXME
return response return response
@dict_interface
def retrieve(self, request): def retrieve(self, request):
""" """
Character control endpoint for re-encrypting and decrypting policy data. Character control endpoint for re-encrypting and decrypting policy data.
@ -238,6 +270,7 @@ class BobJSONControl(BobControl, BobCharacterControlJSONSerializer):
response_data = self.dump_retrieve_output(response=result) response_data = self.dump_retrieve_output(response=result)
return response_data return response_data
@dict_interface
def public_keys(self, request): def public_keys(self, request):
""" """
Character control endpoint for getting Bob's encrypting and signing public keys Character control endpoint for getting Bob's encrypting and signing public keys

View File

@ -33,22 +33,35 @@ class CharacterControlJsonSerializer(CharacterControlSerializer):
raise self.SerializerError(f"Invalid serializer input types: Got {data.__class__.__name__}") raise self.SerializerError(f"Invalid serializer input types: Got {data.__class__.__name__}")
@staticmethod @staticmethod
def __build_response(response_data: dict): def _build_response(response_data: dict):
response_data = {'result': response_data, 'version': str(nucypher.__version__)} response_data = {'result': response_data, 'version': str(nucypher.__version__)}
return response_data return response_data
@staticmethod @staticmethod
def __validate_input(request_data: dict, input_specification: tuple) -> bool: def validate_input(request_data: dict, input_specification: tuple) -> bool:
# Invalid Fields
input_fields = set(request_data.keys())
extra_fields = input_fields - set(input_specification)
if extra_fields:
raise CharacterControlSpecification.InvalidInputField(f"Invalid request fields '{', '.join(extra_fields)}'."
f"Valid fields are: {', '.join(input_specification)}.")
# Missing Fields
missing_fields = list()
for field in input_specification: for field in input_specification:
if field not in request_data: if field not in request_data:
raise CharacterControlSpecification.MissingField(f"Request is missing the '{field}' field") missing_fields.append(missing_fields)
if missing_fields:
raise CharacterControlSpecification.MissingField(f"Request is missing fields: '{', '.join(missing_fields)}' field")
return True return True
@staticmethod @staticmethod
def __validate_output(response_data: dict, output_specification: tuple) -> bool: def validate_output(response_data: dict, output_specification: tuple) -> bool:
for field in output_specification: for field in output_specification:
if field not in response_data['result']: if field not in response_data['result']:
raise CharacterControlSpecification.InvalidResponseField(f"Response is missing the '{field}' field") raise CharacterControlSpecification.InvalidOutputField(f"Response is missing the '{field}' field")
return True return True
def read(self, request_payload: bytes, input_specification: tuple) -> dict: def read(self, request_payload: bytes, input_specification: tuple) -> dict:
@ -60,13 +73,13 @@ class CharacterControlJsonSerializer(CharacterControlSerializer):
except JSONDecodeError: except JSONDecodeError:
raise self.SerializerError(f"Invalid protocol input: got {request_payload}") raise self.SerializerError(f"Invalid protocol input: got {request_payload}")
self.__validate_input(request_data=request_data, input_specification=input_specification) self.validate_input(request_data=request_data, input_specification=input_specification)
return request_data return request_data
def write(self, response_data: dict, output_specification) -> bytes: def write(self, response_data: dict, output_specification) -> bytes:
response_data = self.__build_response(response_data=response_data) response_data = self._build_response(response_data=response_data)
response_payload = CharacterControlJsonSerializer._serializer(response_data) response_payload = CharacterControlJsonSerializer._serializer(response_data)
self.__validate_output(response_data=response_data, output_specification=output_specification) self.validate_output(response_data=response_data, output_specification=output_specification)
return response_payload return response_payload
@ -91,7 +104,8 @@ class AliceCharacterControlJsonSerializer(CharacterControlJsonSerializer):
@staticmethod @staticmethod
def dump_derive_policy_output(response: dict): def dump_derive_policy_output(response: dict):
policy_encrypting_key_hex = bytes(response['policy_encrypting_key']).hex() policy_encrypting_key_hex = bytes(response['policy_encrypting_key']).hex()
response_data = {'policy_encrypting_key': policy_encrypting_key_hex} unicode_label = response['label'].decode()
response_data = {'policy_encrypting_key': policy_encrypting_key_hex, 'label': unicode_label}
return response_data return response_data
@staticmethod @staticmethod

View File

@ -1,4 +1,4 @@
from constant_sorrow.constants import NO_CONTROL_PROTOCOL, NO_WSGI_APP from constant_sorrow.constants import NO_WSGI_APP
from flask import Flask, Response from flask import Flask, Response