Handle invalid metadata bytes in node storage. Unit tests.

pull/1881/head
David Núñez 2020-09-21 13:10:32 +02:00
parent c897a20d53
commit 279f07c9e4
2 changed files with 58 additions and 9 deletions

View File

@ -15,7 +15,6 @@ You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import sqlite3
from pathlib import Path
import OpenSSL
@ -23,6 +22,8 @@ import binascii
import os
import tempfile
from abc import ABC, abstractmethod
from bytestring_splitter import BytestringSplittingError
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import Encoding
@ -305,6 +306,9 @@ class LocalFileBasedNodeStorage(NodeStorage):
class NoNodeMetadataFileFound(FileNotFoundError, NodeStorage.UnknownNode):
pass
class InvalidNodeMetadata(NodeStorage.NodeStorageError):
"""Node metadata is corrupt or not possible to parse"""
def __init__(self,
config_root: str = None,
storage_root: str = None,
@ -394,7 +398,7 @@ class LocalFileBasedNodeStorage(NodeStorage):
self.__METADATA_FILENAME_TEMPLATE.format(checksum_address))
return metadata_path
def __read_metadata(self, filepath: str, federated_only: bool):
def __read_metadata(self, filepath: str):
from nucypher.characters.lawful import Ursula
@ -402,9 +406,12 @@ class LocalFileBasedNodeStorage(NodeStorage):
with open(filepath, "rb") as seed_file:
seed_file.seek(0)
node_bytes = self._decoder(seed_file.read())
node = Ursula.from_bytes(node_bytes)
node = Ursula.from_bytes(node_bytes, fail_fast=True)
except FileNotFoundError:
raise self.UnknownNode
raise self.NoNodeMetadataFileFound
except (BytestringSplittingError, Ursula.UnexpectedVersion):
raise self.InvalidNodeMetadata
return node
def __write_metadata(self, filepath: str, node):
@ -430,10 +437,18 @@ class LocalFileBasedNodeStorage(NodeStorage):
else:
known_nodes = set()
invalid_metadata = []
for filename in filenames:
metadata_path = os.path.join(self.metadata_dir, filename)
node = self.__read_metadata(filepath=metadata_path, federated_only=federated_only) # TODO: 466
known_nodes.add(node)
try:
node = self.__read_metadata(filepath=metadata_path)
except self.NodeStorageError:
invalid_metadata.append(filename)
else:
known_nodes.add(node)
if invalid_metadata:
self.log.warn(f"Couldn't read metadata at {metadata_path} for the following files: {invalid_metadata}")
return known_nodes
@validate_checksum_address
@ -442,7 +457,7 @@ class LocalFileBasedNodeStorage(NodeStorage):
certificate = self.__read_tls_public_certificate(checksum_address=checksum_address)
return certificate
metadata_path = self.__generate_metadata_filepath(checksum_address=checksum_address)
node = self.__read_metadata(filepath=metadata_path, federated_only=federated_only) # TODO: 466
node = self.__read_metadata(filepath=metadata_path)
return node
def store_node_certificate(self, certificate: Certificate, force: bool = True):

View File

@ -15,14 +15,18 @@ You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import os
import pytest
from nucypher.characters.lawful import Ursula
from nucypher.config.storages import ForgetfulNodeStorage, NodeStorage, TemporaryFileBasedNodeStorage
from nucypher.network.nodes import Learner
from tests.constants import MOCK_URSULA_DB_FILEPATH
from tests.utils.ursula import MOCK_URSULA_STARTING_PORT
ADDITIONAL_NODES_TO_LEARN_ABOUT = 10
class BaseTestNodeStorageBackends:
@ -49,7 +53,7 @@ class BaseTestNodeStorageBackends:
# Save more nodes
all_known_nodes = set()
for port in range(MOCK_URSULA_STARTING_PORT, MOCK_URSULA_STARTING_PORT+100):
for port in range(MOCK_URSULA_STARTING_PORT, MOCK_URSULA_STARTING_PORT + ADDITIONAL_NODES_TO_LEARN_ABOUT):
node = Ursula(rest_host='127.0.0.1', db_filepath=MOCK_URSULA_DB_FILEPATH, rest_port=port,
federated_only=True)
node_storage.store_node_metadata(node=node)
@ -58,7 +62,7 @@ class BaseTestNodeStorageBackends:
# Read all nodes from storage
all_stored_nodes = node_storage.all(federated_only=True)
all_known_nodes.add(ursula)
assert len(all_known_nodes) == len(all_stored_nodes)
assert len(all_known_nodes) == len(all_stored_nodes) == 1 + ADDITIONAL_NODES_TO_LEARN_ABOUT
known_checksums = sorted(n.checksum_address for n in all_known_nodes)
stored_checksums = sorted(n.checksum_address for n in all_stored_nodes)
@ -99,6 +103,7 @@ class BaseTestNodeStorageBackends:
def test_read_and_write_to_storage(self, light_ursula):
assert self._read_and_write_metadata(ursula=light_ursula, node_storage=self.storage_backend)
self.storage_backend.clear()
class TestInMemoryNodeStorage(BaseTestNodeStorageBackends):
@ -111,3 +116,32 @@ class TestTemporaryFileBasedNodeStorage(BaseTestNodeStorageBackends):
storage_backend = TemporaryFileBasedNodeStorage(character_class=BaseTestNodeStorageBackends.character_class,
federated_only=BaseTestNodeStorageBackends.federated_only)
storage_backend.initialize()
def test_invalid_metadata(self, light_ursula):
self._read_and_write_metadata(ursula=light_ursula, node_storage=self.storage_backend)
some_node, another_node, *other = os.listdir(self.storage_backend.metadata_dir)
# Let's break the metadata (but not the version)
metadata_path = os.path.join(self.storage_backend.metadata_dir, some_node)
with open(metadata_path, 'wb') as file:
file.write(Learner.LEARNER_VERSION.to_bytes(4, 'big') + b'invalid')
with pytest.raises(TemporaryFileBasedNodeStorage.InvalidNodeMetadata):
self.storage_backend.get(checksum_address=some_node[:-5],
federated_only=True,
certificate_only=False)
# Let's break the metadata, by putting a completely wrong version
metadata_path = os.path.join(self.storage_backend.metadata_dir, another_node)
with open(metadata_path, 'wb') as file:
file.write(b'meh') # Versions are expected to be 4 bytes, but this is 3 bytes
with pytest.raises(TemporaryFileBasedNodeStorage.InvalidNodeMetadata):
self.storage_backend.get(checksum_address=another_node[:-5],
federated_only=True,
certificate_only=False)
# Since there are 2 broken metadata files, we should get 2 nodes less when reading all
restored_nodes = self.storage_backend.all(federated_only=True, certificates_only=False)
total_nodes = 1 + ADDITIONAL_NODES_TO_LEARN_ABOUT
assert total_nodes - 2 == len(restored_nodes)