mirror of https://github.com/nucypher/nucypher.git
Formalize control interface decorators and fix bugs caught as a result. All I/O control validated.
parent
dba872456c
commit
c5aef86a1b
|
@ -5,13 +5,16 @@ class CharacterControlSpecification(ABC):
|
|||
|
||||
specifications = NotImplemented
|
||||
|
||||
class ProtocolError(ValueError):
|
||||
class SpecificationError(ValueError):
|
||||
"""The protocol request is completely unusable"""
|
||||
|
||||
class MissingField(ProtocolError):
|
||||
class MissingField(SpecificationError):
|
||||
"""The protocol request can be deserialized by is missing required fields"""
|
||||
|
||||
class InvalidResponseField(ProtocolError):
|
||||
class InvalidInputField(SpecificationError):
|
||||
"""Response data does not match the output specification"""
|
||||
|
||||
class InvalidOutputField(SpecificationError):
|
||||
"""Response data does not match the output specification"""
|
||||
|
||||
@classmethod
|
||||
|
@ -19,6 +22,6 @@ class CharacterControlSpecification(ABC):
|
|||
try:
|
||||
input_specification, output_specification = cls.specifications[interface_name]
|
||||
except KeyError:
|
||||
raise cls.ProtocolError(f"No Such Interface '{interface_name}'")
|
||||
raise cls.SpecificationError(f"No Such Control Interface '{interface_name}'")
|
||||
return input_specification, output_specification
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import functools
|
||||
from typing import Tuple, Callable
|
||||
|
||||
import click
|
||||
|
@ -12,21 +13,47 @@ from nucypher.crypto.kits import UmbralMessageKit
|
|||
from nucypher.crypto.powers import DecryptingPower, SigningPower
|
||||
|
||||
|
||||
def dict_interface(func) -> Callable:
|
||||
"""Validate I/O specification for dictionary character control interfaces"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped(instance, request=None, *args, **kwargs) -> bytes:
|
||||
|
||||
# Get Specification
|
||||
input_specification, output_specification = instance.get_specifications(interface_name=func.__name__)
|
||||
|
||||
if request:
|
||||
instance.validate_input(request_data=request, input_specification=input_specification)
|
||||
|
||||
# Call Interface
|
||||
response_dict = func(self=instance, request=request, *args, **kwargs)
|
||||
|
||||
# Output
|
||||
response_dict = instance._build_response(response_data=response_dict)
|
||||
instance.validate_output(response_data=response_dict, output_specification=output_specification)
|
||||
|
||||
return response_dict
|
||||
return wrapped
|
||||
|
||||
|
||||
def bytes_interface(func) -> Callable:
|
||||
"""Manage protocol I/O validation and serialization"""
|
||||
|
||||
def wrapped(instance, request, *args, **kwargs) -> bytes:
|
||||
@functools.wraps(func)
|
||||
def wrapped(instance, request=None, *args, **kwargs) -> bytes:
|
||||
|
||||
# Get Specification
|
||||
input_specification, output_specification = instance.get_specifications(interface_name=func.__name__)
|
||||
|
||||
# Read
|
||||
interface_name = func.__name__
|
||||
input_specification, output_specification = instance.get_specifications(interface_name=interface_name)
|
||||
request_data = instance.read(request_payload=request, input_specification=input_specification)
|
||||
if request:
|
||||
request = instance.read(request_payload=request, input_specification=input_specification)
|
||||
|
||||
# Inner Call
|
||||
response_data = func(instance, request_data, *args, **kwargs)
|
||||
response = func(self=instance, request=request, *args, **kwargs)
|
||||
|
||||
# Write
|
||||
response_bytes = instance.write(response_data=response_data, output_specification=output_specification)
|
||||
response_bytes = instance.write(response_data=response, output_specification=output_specification)
|
||||
|
||||
return response_bytes
|
||||
return wrapped
|
||||
|
@ -56,7 +83,7 @@ class AliceControl(CharacterControl, CharacterControlSpecification):
|
|||
('label', 'policy_encrypting_key')) # Out
|
||||
|
||||
__derive_policy = (('label', ), # In
|
||||
('policy_encrypting_key',)) # Out
|
||||
('policy_encrypting_key', 'label')) # Out
|
||||
|
||||
__grant = (('bob_encrypting_key', 'bob_verifying_key', 'm', 'n', 'label', 'expiration'), # In
|
||||
('treasure_map', 'policy_encrypting_key', 'alice_signing_key', 'label')) # Out
|
||||
|
@ -187,18 +214,21 @@ class BobControl(CharacterControlSpecification):
|
|||
|
||||
class AliceJSONControl(AliceControl, AliceCharacterControlJsonSerializer):
|
||||
|
||||
@dict_interface
|
||||
def create_policy(self, request):
|
||||
federated_only = True # TODO: const for now
|
||||
result = super().create_policy(**self.load_create_policy_input(request=request), federated_only=federated_only)
|
||||
response_data = self.dump_create_policy_output(response=result)
|
||||
return response_data
|
||||
|
||||
def derive_policy(self, request, label: str):
|
||||
@dict_interface
|
||||
def derive_policy(self, label: str, request=None):
|
||||
label_bytes = label.encode()
|
||||
result = super().derive_policy(label=label_bytes)
|
||||
response_data = self.dump_derive_policy_output(response=result)
|
||||
return response_data
|
||||
|
||||
@dict_interface
|
||||
def grant(self, request):
|
||||
result = super().grant(**self.parse_grant_input(request=request))
|
||||
response_data = self.dump_grant_output(response=result)
|
||||
|
@ -222,6 +252,7 @@ class AliceJSONBytesControl(AliceJSONControl, AliceCharacterControlJsonSerialize
|
|||
|
||||
class BobJSONControl(BobControl, BobCharacterControlJSONSerializer):
|
||||
|
||||
@dict_interface
|
||||
def join_policy(self, request):
|
||||
"""
|
||||
Character control endpoint for joining a policy on the network.
|
||||
|
@ -230,6 +261,7 @@ class BobJSONControl(BobControl, BobCharacterControlJSONSerializer):
|
|||
response = {'policy_encrypting_key': 'OK'} # FIXME
|
||||
return response
|
||||
|
||||
@dict_interface
|
||||
def retrieve(self, request):
|
||||
"""
|
||||
Character control endpoint for re-encrypting and decrypting policy data.
|
||||
|
@ -238,6 +270,7 @@ class BobJSONControl(BobControl, BobCharacterControlJSONSerializer):
|
|||
response_data = self.dump_retrieve_output(response=result)
|
||||
return response_data
|
||||
|
||||
@dict_interface
|
||||
def public_keys(self, request):
|
||||
"""
|
||||
Character control endpoint for getting Bob's encrypting and signing public keys
|
||||
|
|
|
@ -33,22 +33,35 @@ class CharacterControlJsonSerializer(CharacterControlSerializer):
|
|||
raise self.SerializerError(f"Invalid serializer input types: Got {data.__class__.__name__}")
|
||||
|
||||
@staticmethod
|
||||
def __build_response(response_data: dict):
|
||||
def _build_response(response_data: dict):
|
||||
response_data = {'result': response_data, 'version': str(nucypher.__version__)}
|
||||
return response_data
|
||||
|
||||
@staticmethod
|
||||
def __validate_input(request_data: dict, input_specification: tuple) -> bool:
|
||||
def validate_input(request_data: dict, input_specification: tuple) -> bool:
|
||||
|
||||
# Invalid Fields
|
||||
input_fields = set(request_data.keys())
|
||||
extra_fields = input_fields - set(input_specification)
|
||||
|
||||
if extra_fields:
|
||||
raise CharacterControlSpecification.InvalidInputField(f"Invalid request fields '{', '.join(extra_fields)}'."
|
||||
f"Valid fields are: {', '.join(input_specification)}.")
|
||||
|
||||
# Missing Fields
|
||||
missing_fields = list()
|
||||
for field in input_specification:
|
||||
if field not in request_data:
|
||||
raise CharacterControlSpecification.MissingField(f"Request is missing the '{field}' field")
|
||||
missing_fields.append(missing_fields)
|
||||
if missing_fields:
|
||||
raise CharacterControlSpecification.MissingField(f"Request is missing fields: '{', '.join(missing_fields)}' field")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def __validate_output(response_data: dict, output_specification: tuple) -> bool:
|
||||
def validate_output(response_data: dict, output_specification: tuple) -> bool:
|
||||
for field in output_specification:
|
||||
if field not in response_data['result']:
|
||||
raise CharacterControlSpecification.InvalidResponseField(f"Response is missing the '{field}' field")
|
||||
raise CharacterControlSpecification.InvalidOutputField(f"Response is missing the '{field}' field")
|
||||
return True
|
||||
|
||||
def read(self, request_payload: bytes, input_specification: tuple) -> dict:
|
||||
|
@ -60,13 +73,13 @@ class CharacterControlJsonSerializer(CharacterControlSerializer):
|
|||
except JSONDecodeError:
|
||||
raise self.SerializerError(f"Invalid protocol input: got {request_payload}")
|
||||
|
||||
self.__validate_input(request_data=request_data, input_specification=input_specification)
|
||||
self.validate_input(request_data=request_data, input_specification=input_specification)
|
||||
return request_data
|
||||
|
||||
def write(self, response_data: dict, output_specification) -> bytes:
|
||||
response_data = self.__build_response(response_data=response_data)
|
||||
response_data = self._build_response(response_data=response_data)
|
||||
response_payload = CharacterControlJsonSerializer._serializer(response_data)
|
||||
self.__validate_output(response_data=response_data, output_specification=output_specification)
|
||||
self.validate_output(response_data=response_data, output_specification=output_specification)
|
||||
return response_payload
|
||||
|
||||
|
||||
|
@ -91,7 +104,8 @@ class AliceCharacterControlJsonSerializer(CharacterControlJsonSerializer):
|
|||
@staticmethod
|
||||
def dump_derive_policy_output(response: dict):
|
||||
policy_encrypting_key_hex = bytes(response['policy_encrypting_key']).hex()
|
||||
response_data = {'policy_encrypting_key': policy_encrypting_key_hex}
|
||||
unicode_label = response['label'].decode()
|
||||
response_data = {'policy_encrypting_key': policy_encrypting_key_hex, 'label': unicode_label}
|
||||
return response_data
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from constant_sorrow.constants import NO_CONTROL_PROTOCOL, NO_WSGI_APP
|
||||
from constant_sorrow.constants import NO_WSGI_APP
|
||||
from flask import Flask, Response
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue