nucypher/tests/utils/middleware.py

184 lines
6.1 KiB
Python

import random
import socket
import time
from pathlib import Path
import requests
from flask import Response
from nucypher_core import MetadataRequest, FleetStateChecksum
from nucypher.characters.lawful import Ursula
from nucypher.network.middleware import NucypherMiddlewareClient, RestMiddleware
from tests.utils.ursula import MOCK_KNOWN_URSULAS_CACHE
class BadTestUrsulas(RuntimeError):
crash_right_now = True
class _TestMiddlewareClient(NucypherMiddlewareClient):
timeout = None
@staticmethod
def response_cleaner(response):
response.content = response.data
return response
def _get_mock_client_by_ursula(self, ursula):
port = ursula.rest_interface.port
return self._get_mock_client_by_port(port)
def _get_mock_client_by_url(self, url):
port = int(url.split(":")[1])
return self._get_mock_client_by_port(port)
def _get_mock_client_by_port(self, port):
ursula = self._get_ursula_by_port(port)
rest_app = ursula.rest_app
rest_app.testing = True
mock_client = rest_app.test_client()
return mock_client
def _get_ursula_by_port(self, port):
mkuc = MOCK_KNOWN_URSULAS_CACHE
try:
return mkuc[port]
except KeyError:
raise BadTestUrsulas(
"Can't find an Ursula with port {} - did you spin up the right test ursulas?".format(port))
def parse_node_or_host_and_port(self, node=None, host=None, port=None):
if node:
if any((host, port)):
raise ValueError("Don't pass host and port if you are passing the node.")
mock_client = self._get_mock_client_by_ursula(node)
elif all((host, port)):
node = self._get_ursula_by_port(port)
mock_client = self._get_mock_client_by_port(port)
else:
raise ValueError("You need to pass either the node or a host and port.")
host, port = node.rest_interface.host, node.rest_interface.port
return host, port, mock_client
def invoke_method(self, method, url, *args, **kwargs):
_cert_location = kwargs.pop("verify") # TODO: Is this something that can be meaningfully tested?
kwargs.pop("timeout", None) # Just get rid of timeout; not needed for the test client.
response = super().invoke_method(method, url, *args, **kwargs)
return response
def clean_params(self, request_kwargs):
request_kwargs["query_string"] = request_kwargs.pop("params", {})
def get_certificate(self, port, *args, **kwargs):
ursula = self._get_ursula_by_port(port)
return ursula.certificate, Path()
class MockRestMiddleware(RestMiddleware):
_ursulas = None
_client_class = _TestMiddlewareClient
class NotEnoughMockUrsulas(Ursula.NotEnoughUrsulas):
pass
class MockRestMiddlewareForLargeFleetTests(MockRestMiddleware):
"""
A MockRestMiddleware with workaround necessary to test the conditions that arise with thousands of nodes.
"""
def get_nodes_via_rest(self,
node,
fleet_state_checksum,
announce_nodes=None):
response_bytes = node.bytestring_of_known_nodes()
r = Response(response_bytes)
r.content = r.data
return r
class SluggishLargeFleetMiddleware(MockRestMiddlewareForLargeFleetTests):
"""
Similar to above, but with added delay to simulate network latency.
"""
def put_treasure_map_on_node(self, node, *args, **kwargs):
time.sleep(random.randrange(5, 15) / 100)
result = super().put_treasure_map_on_node(node=node, *args, **kwargs)
time.sleep(random.randrange(5, 15) / 100)
return result
class _MiddlewareClientWithConnectionProblems(_TestMiddlewareClient):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ports_that_are_down = set()
self.certs_are_broken = False
def _get_ursula_by_port(self, port):
if port in self.ports_that_are_down:
raise ConnectionRefusedError
else:
return super()._get_ursula_by_port(port)
def get(self, *args, **kwargs):
if kwargs.get("path") == "public_information":
if self.certs_are_broken:
raise requests.exceptions.SSLError
port = kwargs.get("port")
if port in self.ports_that_are_down:
raise socket.gaierror
real_get = super(_TestMiddlewareClient, self).__getattr__("get")
return real_get(*args, **kwargs)
class NodeIsDownMiddleware(MockRestMiddleware):
"""
Modified middleware to emulate one node being down amongst many.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = _MiddlewareClientWithConnectionProblems()
def node_is_down(self, node):
self.client.ports_that_are_down.add(node.rest_interface.port)
def node_is_up(self, node):
self.client.ports_that_are_down.remove(node.rest_interface.port)
def all_nodes_up(self):
self.client.ports_that_are_down = set()
def all_nodes_down(self):
self.client.ports_that_are_down = set(MOCK_KNOWN_URSULAS_CACHE)
class EvilMiddleWare(MockRestMiddleware):
"""
Middleware for assholes.
"""
def propagate_shitty_interface_id(self, ursula, shitty_metadata):
"""
Try to get Ursula to propagate a malicious (or otherwise shitty) interface ID.
"""
fleet_state_checksum = FleetStateChecksum(this_node=None, other_nodes=[])
request = MetadataRequest(fleet_state_checksum=fleet_state_checksum, announce_nodes=[shitty_metadata])
response = self.client.post(node_or_sprout=ursula,
path="node_metadata",
data=bytes(request)
)
return response
def upload_arbitrary_data(self, node, path, data):
response = self.client.post(node_or_sprout=node,
path=path,
data=data)
return response