Add a basic first implementation of the datastore with tests

Add tests for the datastore models, and change 'NotFound' to 'RecordNotFound'
pull/2099/head
tuxxy 2020-06-24 13:13:00 +02:00
parent 184b3bf563
commit cc0ca2ef91
4 changed files with 364 additions and 216 deletions

View File: nucypher/datastore/base.py

@@ -43,6 +43,7 @@ class DatastoreRecord:
                  record_id: Union[int, str],
                  writeable: bool = False) -> None:
         self._record_id = record_id
+        self._fields = [field[1:] for field in type(self).__dict__ if type(type(self).__dict__[field]) == RecordField]
         self.__db_tx = db_tx
         self.__writeable = writeable
@@ -67,6 +68,7 @@ class DatastoreRecord:
         """
         # When writeable is None (meaning, it hasn't been __init__ yet), then
         # we allow any attribute to be set on the instance.
+        # HOT LAVA -- causes a recursion if this check isn't present.
         if self.__writeable is None:
             super().__setattr__(attr, value)
@@ -94,6 +96,7 @@ class DatastoreRecord:
         its `RecordField.field_type`, then this method will raise a `TypeError`.
         """
         # Handle __getattr__ look ups for private fields
+        # HOT LAVA -- causes a recursion if this check isn't present.
         if attr.startswith('_'):
             return super().__getattr__(attr)
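
For context, the added `_fields` attribute derives the public field names by scanning the class dictionary for `RecordField` descriptors and stripping the leading underscore from each name. A minimal sketch of that derivation (the `RecordField` stand-in here is simplified and hypothetical; the real descriptor lives in nucypher.datastore.base):

    # Simplified, hypothetical stand-in for nucypher.datastore.base.RecordField.
    class RecordField:
        def __init__(self, field_type, encode=lambda v: v, decode=lambda v: v):
            self.field_type = field_type
            self.encode = encode
            self.decode = decode

    class ExampleRecord:  # hypothetical record class
        _test = RecordField(bytes)
        _test_date = RecordField(str)

    # The same comprehension as the added line above; the type check filters
    # out dunder entries like __module__ that also live in the class dict.
    fields = [name[1:] for name in ExampleRecord.__dict__
              if type(ExampleRecord.__dict__[name]) == RecordField]
    assert fields == ['test', 'test_date']  # underscore prefixes stripped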

View File: nucypher/datastore/datastore.py

@@ -14,204 +14,261 @@ GNU Affero General Public License for more details.
 You should have received a copy of the GNU Affero General Public License
 along with nucypher. If not, see <https://www.gnu.org/licenses/>.
 """
+import lmdb
 import maya
+from contextlib import contextmanager, suppress
 from bytestring_splitter import BytestringSplitter
-from datetime import datetime
-from typing import List
-from umbral.keys import UmbralPublicKey
-from umbral.kfrags import KFrag
+from typing import Union

 from nucypher.crypto.signing import Signature
-from nucypher.crypto.utils import fingerprint_from_key
-from nucypher.datastore.db.base import DatastoreRecord, RecordField
-from nucypher.datastore.db.models import PolicyArrangement, Workorder
+from nucypher.datastore.base import DatastoreRecord, RecordField
+from nucypher.datastore.models import PolicyArrangement, Workorder


-class NotFound(Exception):
+class RecordNotFound(Exception):
     """
     Exception class for Datastore calls for objects that don't exist.
     """
     pass


+class DatastoreTransactionError(Exception):
+    """
+    Exception class for errors during transactions in the datastore.
+    """
+    pass
+
+
 class Datastore:
     """
-    A storage class of persistent cryptographic entities for use by Ursula.
+    A persistent storage layer for arbitrary data for use by NuCypher characters.
     """
-    kfrag_splitter = BytestringSplitter(Signature, (KFrag, KFrag.expected_bytes_length()))
-
-    def __init__(self, sqlalchemy_engine=None) -> None:
+    # LMDB has a `map_size` arg that caps the total size of the database.
+    # We can set this arbitrarily high (1TB) to prevent any run-time crashes.
+    LMDB_MAP_SIZE = 1_000_000_000_000
+
+    def __init__(self, db_path: str) -> None:
         """
-        Initializes a Datastore object.
+        Initializes a Datastore object by path.

-        :param sqlalchemy_engine: SQLAlchemy engine object to create session
+        :param db_path: Filepath to a lmdb database.
         """
-        self.engine = sqlalchemy_engine
-        Session = sessionmaker(bind=sqlalchemy_engine)
-
-        # This will probably be on the reactor thread for most production configs.
-        # Best to treat like hot lava.
-        self._session_on_init_thread = Session()
-
-    @staticmethod
-    def __commit(session) -> None:
-        try:
-            session.commit()
-        except OperationalError:
-            session.rollback()
-            raise
+        self.db_path = db_path
+        self.__db_env = lmdb.open(db_path, map_size=self.LMDB_MAP_SIZE)
+
+    @contextmanager
+    def describe(self, record_type: 'DatastoreRecord', record_id: Union[int, str], writeable: bool = False):
+        """
+        This method is used to perform CRUD operations on the datastore within
+        the safety of a context manager by returning an instance of the
+        `record_type` identified by the `record_id` provided.
+
+        When `writeable` is `False`, the record returned by this method
+        cannot be used for any operations that write to the datastore. If an
+        attempt is made to retrieve a non-existent record whilst `writeable`
+        is `False`, this method raises a `RecordNotFound` error.
+
+        When `writeable` is `True`, the record can be used to perform writes
+        on the datastore. In the event an error occurs during the write, the
+        transaction will be aborted and no data will be written, and a
+        `DatastoreTransactionError` will be raised.
+
+        If the record is used outside the scope of the context manager, any
+        writes or reads will error.
+        """
+        try:
+            with self.__db_env.begin(write=writeable) as datastore_tx:
+                record = record_type(datastore_tx, record_id, writeable=writeable)
+                yield record
+        except (AttributeError, TypeError) as tx_err:
+            # Handle `RecordNotFound` cases when `writeable` is `False`.
+            if not writeable and isinstance(tx_err, AttributeError):
+                raise RecordNotFound(tx_err)
+            raise DatastoreTransactionError(f'An error was encountered during the transaction (no data was written): {tx_err}')
+        finally:
+            # Set the `writeable` instance variable to `False` so that writes
+            # cannot be attempted on the leftover reference. This isn't really
+            # possible because the `datastore_tx` is no longer usable, but
+            # we set this to ensure some degree of safety.
+            record.__dict__['_DatastoreRecord__writeable'] = False
-    #
-    # Arrangements
-    #
-
-    def add_policy_arrangement(self,
-                               expiration: maya.MayaDT,
-                               arrangement_id: bytes,
-                               kfrag: KFrag = None,
-                               alice_verifying_key: UmbralPublicKey = None,
-                               alice_signature: Signature = None,  # TODO: Why is this unused?
-                               session=None
-                               ) -> PolicyArrangement:
-        """
-        Creates a PolicyArrangement to the Keystore.
-
-        :return: The newly added PolicyArrangement object
-        """
-        session = session or self._session_on_init_thread
-
-        new_policy_arrangement = PolicyArrangement(
-            expiration=expiration,
-            id=arrangement_id,
-            kfrag=kfrag,
-            alice_verifying_key=bytes(alice_verifying_key),
-            alice_signature=None,
-            # bob_verifying_key.id  # TODO: Is this needed?
-        )
-
-        session.add(new_policy_arrangement)
-        self.__commit(session=session)
-        return new_policy_arrangement
-
-    def get_policy_arrangement(self, arrangement_id: bytes, session=None) -> PolicyArrangement:
-        """
-        Retrieves a PolicyArrangement by its HRAC.
-
-        :return: The PolicyArrangement object
-        """
-        session = session or self._session_on_init_thread
-        policy_arrangement = session.query(PolicyArrangement).filter_by(id=arrangement_id).first()
-        if not policy_arrangement:
-            raise NotFound("No PolicyArrangement {} found.".format(arrangement_id))
-        return policy_arrangement
-
-    def get_all_policy_arrangements(self, session=None) -> List[PolicyArrangement]:
-        """
-        Returns all the PolicyArrangements
-
-        :return: The list of PolicyArrangement objects
-        """
-        session = session or self._session_on_init_thread
-        arrangements = session.query(PolicyArrangement).all()
-        return arrangements
-
-    def attach_kfrag_to_saved_arrangement(self, alice, id_as_hex, kfrag, session=None):
-        session = session or self._session_on_init_thread
-        policy_arrangement = session.query(PolicyArrangement).filter_by(id=id_as_hex.encode()).first()
-
-        if policy_arrangement is None:
-            raise NotFound("Can't attach a kfrag to non-existent Arrangement {}".format(id_as_hex))
-
-        if policy_arrangement.alice_verifying_key != alice.stamp:
-            raise alice.SuspiciousActivity
-
-        policy_arrangement.kfrag = bytes(kfrag)
-        self.__commit(session=session)
-
-    def del_policy_arrangement(self, arrangement_id: bytes, session=None) -> int:
-        """
-        Deletes a PolicyArrangement from the Keystore.
-        """
-        session = session or self._session_on_init_thread
-        deleted_records = session.query(PolicyArrangement).filter_by(id=arrangement_id).delete()
-
-        self.__commit(session=session)
-        return deleted_records
-
-    def del_expired_policy_arrangements(self, session=None, now=None) -> int:
-        """
-        Deletes all expired PolicyArrangements from the Keystore.
-        """
-        session = session or self._session_on_init_thread
-        now = now or datetime.now()
-        result = session.query(PolicyArrangement).filter(PolicyArrangement.expiration <= now)
-
-        deleted_records = 0
-        if result.count() > 0:
-            deleted_records = result.delete()
-        self.__commit(session=session)
-        return deleted_records
-
-    #
-    # Work Orders
-    #
-
-    def save_workorder(self,
-                       bob_verifying_key: UmbralPublicKey,
-                       bob_signature: Signature,
-                       arrangement_id: bytes,
-                       session=None
-                       ) -> Workorder:
-        """
-        Adds a Workorder to the keystore.
-        """
-        session = session or self._session_on_init_thread
-
-        new_workorder = Workorder(bob_verifying_key=bytes(bob_verifying_key),
-                                  bob_signature=bob_signature,
-                                  arrangement_id=arrangement_id)
-
-        session.add(new_workorder)
-        self.__commit(session=session)
-        return new_workorder
-
-    def get_workorders(self,
-                       arrangement_id: bytes = None,
-                       bob_verifying_key: bytes = None,
-                       session=None
-                       ) -> List[Workorder]:
-        """
-        Returns a list of Workorders by HRAC.
-        """
-        session = session or self._session_on_init_thread
-        query = session.query(Workorder)
-
-        if not arrangement_id and not bob_verifying_key:
-            workorders = query.all()  # Return all records
-
-        else:
-            # Return arrangement records
-            if arrangement_id:
-                workorders = query.filter_by(arrangement_id=arrangement_id)
-
-            # Return records for Bob
-            else:
-                workorders = query.filter_by(bob_verifying_key=bob_verifying_key)
-
-        if not workorders:
-            raise NotFound
-
-        return list(workorders)
-
-    def del_workorders(self, arrangement_id: bytes, session=None) -> int:
-        """
-        Deletes a Workorder from the Keystore.
-        """
-        session = session or self._session_on_init_thread
-
-        workorders = session.query(Workorder).filter_by(arrangement_id=arrangement_id)
-        deleted = workorders.delete()
-        self.__commit(session=session)
-        return deleted
+
+
+#class Datastore:
+#    """
+#    A storage class of persistent cryptographic entities for use by Ursula.
+#    """
+#    kfrag_splitter = BytestringSplitter(Signature, (KFrag, KFrag.expected_bytes_length()))
+#
+#    def __init__(self, sqlalchemy_engine=None) -> None:
+#    [... the remainder of the old SQLAlchemy implementation, re-added verbatim as comments, line for line ...]
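
A minimal usage sketch of the new `describe` context manager, assuming the `datastore` module import path implied by the tests below; `NoteRecord` and its `_note` field are hypothetical:

    import tempfile
    from nucypher.datastore import datastore
    from nucypher.datastore.base import DatastoreRecord, RecordField

    class NoteRecord(DatastoreRecord):  # hypothetical record type
        _note = RecordField(bytes)

    storage = datastore.Datastore(tempfile.mkdtemp())

    # Writes require writeable=True; any error inside the block aborts the
    # whole transaction and surfaces as DatastoreTransactionError.
    with storage.describe(NoteRecord, 'note-1', writeable=True) as record:
        record.note = b'hello'

    # Reads default to writeable=False; a missing record raises RecordNotFound.
    with storage.describe(NoteRecord, 'note-1') as record:
        assert record.note == b'hello'

Note that fields are declared with a leading underscore (`_note`) but read and written without it (`record.note`), matching the `_fields` derivation in base.py.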

View File: nucypher/datastore/models.py

@@ -23,23 +23,23 @@ from nucypher.datastore.base import DatastoreRecord, RecordField

 class PolicyArrangement(DatastoreRecord):
-    arrangement_id = RecordField(bytes)
-    expiration = RecordField(MayaDT,
-                             encode=lambda maya_date: maya_date.iso8601().encode(),
-                             decode=lambda maya_bytes: MayaDT.from_iso8601(maya_bytes.decode()))
-    kfrag = RecordField(KFrag,
-                        encode=lambda kfrag: kfrag.to_bytes(),
-                        decode=KFrag.from_bytes)
-    alice_verifying_key = RecordField(UmbralPublicKey,
-                                      encode=bytes,
-                                      decode=UmbralPublicKey.from_bytes)
+    _arrangement_id = RecordField(bytes)
+    _expiration = RecordField(MayaDT,
+                              encode=lambda maya_date: maya_date.iso8601().encode(),
+                              decode=lambda maya_bytes: MayaDT.from_iso8601(maya_bytes.decode()))
+    _kfrag = RecordField(KFrag,
+                         encode=lambda kfrag: kfrag.to_bytes(),
+                         decode=KFrag.from_bytes)
+    _alice_verifying_key = RecordField(UmbralPublicKey,
+                                       encode=bytes,
+                                       decode=UmbralPublicKey.from_bytes)


 class Workorder(DatastoreRecord):
-    arrangement_id = RecordField(bytes)
-    bob_verifying_key = RecordField(UmbralPublicKey,
-                                    encode=bytes,
-                                    decode=UmbralPublicKey.from_bytes)
-    bob_signature = RecordField(Signature,
-                                encode=bytes,
-                                decode=Signature.from_bytes)
+    _arrangement_id = RecordField(bytes)
+    _bob_verifying_key = RecordField(UmbralPublicKey,
+                                     encode=bytes,
+                                     decode=UmbralPublicKey.from_bytes)
+    _bob_signature = RecordField(Signature,
+                                 encode=bytes,
+                                 decode=Signature.from_bytes)
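
Each `RecordField` carries an `encode`/`decode` pair that serializes the field's value to bytes for storage; for example, `_expiration` round-trips a `MayaDT` through its ISO 8601 string. A quick sketch of that round trip, using the same lambdas as the model above:

    import maya
    from maya import MayaDT

    encode = lambda maya_date: maya_date.iso8601().encode()
    decode = lambda maya_bytes: MayaDT.from_iso8601(maya_bytes.decode())

    expiration = maya.now()
    # Lossless round trip, as the model tests below also assert.
    assert decode(encode(expiration)) == expiration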

View File: tests/datastore/test_datastore.py

@@ -26,6 +26,73 @@ from nucypher.datastore.base import DatastoreRecord, RecordField
 from nucypher.datastore.models import PolicyArrangement, Workorder


+def test_datastore():
+    class TestRecord(DatastoreRecord):
+        _test = RecordField(bytes)
+        _test_date = RecordField(datetime,
+                                 encode=lambda val: datetime.isoformat(val).encode(),
+                                 decode=lambda val: datetime.fromisoformat(val.decode()))
+
+    temp_path = tempfile.mkdtemp()
+    storage = datastore.Datastore(temp_path)
+    assert storage.LMDB_MAP_SIZE == 1_000_000_000_000
+    assert storage.db_path == temp_path
+    assert storage._Datastore__db_env.path() == temp_path
+
+    # Test writing
+    # Writing to a valid field works!
+    with storage.describe(TestRecord, 'test_id', writeable=True) as test_record:
+        test_record.test = b'test data'
+        assert test_record.test == b'test data'
+
+    # Check that you can't reuse the record instance to write outside the context manager
+    with pytest.raises(TypeError):
+        test_record.test = b'should not write'
+    # Nor can you read outside the context manager
+    with pytest.raises(lmdb.Error):
+        should_error = test_record.test
+
+    # Records can also have ints as IDs
+    with storage.describe(TestRecord, 1337, writeable=True) as test_record:
+        test_record.test = b'test int ID'
+        assert test_record.test == b'test int ID'
+
+    # Writing to a non-existent field errors
+    with pytest.raises(datastore.DatastoreTransactionError):
+        with storage.describe(TestRecord, 'test_id', writeable=True) as test_record:
+            test_record.nonexistent_field = b'this will error'
+
+    # Writing the wrong type to a field errors
+    with pytest.raises(datastore.DatastoreTransactionError):
+        with storage.describe(TestRecord, 'test_id', writeable=True) as test_record:
+            test_record.test = 1234
+    # Check that nothing was written
+    with storage.describe(TestRecord, 'test_id') as test_record:
+        assert test_record.test != 1234
+
+    # An error in the context manager results in a transaction abort
+    with pytest.raises(datastore.DatastoreTransactionError):
+        with storage.describe(TestRecord, 'test_id', writeable=True) as test_record:
+            # Valid write
+            test_record.test = b'this will not persist'
+            # Erroneous write causing an abort
+            test_record.nonexistent = b'causes an error and aborts the write'
+
+    # Test reading
+    # Getting read-only access to a record can be done by not setting `writeable` to `True`.
+    # `writeable` is, by default, `False`.
+    # Check that nothing was written from the aborted transaction above.
+    with storage.describe(TestRecord, 'test_id') as test_record:
+        assert test_record.test == b'test data'
+
+    # In the event a record doesn't exist, this will raise a `RecordNotFound` error iff `writeable=False`.
+    with pytest.raises(datastore.RecordNotFound):
+        with storage.describe(TestRecord, 'nonexistent') as test_record:
+            should_error = test_record.test
+
+
 def test_datastore_record_read():
     class TestRecord(DatastoreRecord):
         _test = RecordField(bytes)
@@ -38,6 +105,7 @@ def test_datastore_record_read():
     # Check the default attrs.
     test_rec = TestRecord(db_tx, 'testing', writeable=False)
     assert test_rec._record_id == 'testing'
+    assert test_rec._fields == ['test', 'test_date']
     assert test_rec._DatastoreRecord__db_tx == db_tx
     assert test_rec._DatastoreRecord__writeable == False
     assert test_rec._DatastoreRecord__storagekey == 'TestRecord:{record_field}:{record_id}'
@@ -81,37 +149,57 @@ def test_datastore_record_write():
         # Test writing a valid field and getting it.
         test_rec.test = b'good write'
         assert test_rec.test == b'good write'
+        assert msgpack.unpackb(db_tx.get(b'TestRecord:test:testing')) == b'good write'
         # TODO: Mock a `DBWriteError`
+
+    # Test abort
+    with pytest.raises(lmdb.Error):
+        with db_env.begin(write=True) as db_tx:
+            test_rec = TestRecord(db_tx, 'testing', writeable=True)
+            test_rec.test = b'should not be set'
+            db_tx.abort()
+
+    # After abort, the value should still be the one before the previous `put`
+    with db_env.begin() as db_tx:
+        test_rec = TestRecord(db_tx, 'testing', writeable=False)
+        assert test_rec.test == b'good write'


-# def test_datastore_policy_arrangement_model():
-#     arrangement_id = b'test'
-#     expiration = maya.now()
-#     alice_verifying_key = keypairs.SigningKeypair(generate_keys_if_needed=True).pubkey
-#
-#     # TODO: Leaving out KFrag for now since I don't have an easy way to grab one.
-#     test_record = PolicyArrangement(arrangement_id=arrangement_id,
-#                                     expiration=expiration,
-#                                     alice_verifying_key=alice_verifying_key)
-#
-#     assert test_record.arrangement_id == arrangement_id
-#     assert test_record.expiration == expiration
-#     assert alice_verifying_key == alice_verifying_key
-#     assert test_record == PolicyArrangement.from_bytes(test_record.to_bytes())
-#
-#
-# def test_datastore_workorder_model():
-#     bob_keypair = keypairs.SigningKeypair(generate_keys_if_needed=True)
-#
-#     arrangement_id = b'test'
-#     bob_verifying_key = bob_keypair.pubkey
-#     bob_signature = bob_keypair.sign(b'test')
-#
-#     test_record = Workorder(arrangement_id=arrangement_id,
-#                             bob_verifying_key=bob_verifying_key,
-#                             bob_signature=bob_signature)
-#
-#     assert test_record.arrangement_id == arrangement_id
-#     assert test_record.bob_verifying_key == bob_verifying_key
-#     assert test_record.bob_signature == bob_signature
-#     assert test_record == Workorder.from_bytes(test_record.to_bytes())
+def test_datastore_policy_arrangement_model():
+    temp_path = tempfile.mkdtemp()
+    storage = datastore.Datastore(temp_path)
+
+    arrangement_id_hex = 'beef'
+    expiration = maya.now()
+    alice_verifying_key = keypairs.SigningKeypair(generate_keys_if_needed=True).pubkey
+
+    # TODO: Leaving out KFrag for now since I don't have an easy way to grab one.
+    with storage.describe(PolicyArrangement, arrangement_id_hex, writeable=True) as policy_arrangement:
+        policy_arrangement.arrangement_id = bytes.fromhex(arrangement_id_hex)
+        policy_arrangement.expiration = expiration
+        policy_arrangement.alice_verifying_key = alice_verifying_key
+
+    with storage.describe(PolicyArrangement, arrangement_id_hex) as policy_arrangement:
+        assert policy_arrangement.arrangement_id == bytes.fromhex(arrangement_id_hex)
+        assert policy_arrangement.expiration == expiration
+        assert policy_arrangement.alice_verifying_key == alice_verifying_key
+
+
+def test_datastore_workorder_model():
+    temp_path = tempfile.mkdtemp()
+    storage = datastore.Datastore(temp_path)
+    bob_keypair = keypairs.SigningKeypair(generate_keys_if_needed=True)
+
+    arrangement_id_hex = 'beef'
+    bob_verifying_key = bob_keypair.pubkey
+    bob_signature = bob_keypair.sign(b'test')
+
+    with storage.describe(Workorder, arrangement_id_hex, writeable=True) as work_order:
+        work_order.arrangement_id = bytes.fromhex(arrangement_id_hex)
+        work_order.bob_verifying_key = bob_verifying_key
+        work_order.bob_signature = bob_signature
+
+    with storage.describe(Workorder, arrangement_id_hex) as work_order:
+        assert work_order.arrangement_id == bytes.fromhex(arrangement_id_hex)
+        assert work_order.bob_verifying_key == bob_verifying_key
+        assert work_order.bob_signature == bob_signature
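
As the `msgpack.unpackb(db_tx.get(b'TestRecord:test:testing'))` assertion above suggests, each field appears to be stored under a `{RecordType}:{field}:{record_id}` key with a msgpack-packed value. A raw-read sketch against a datastore directory (the path is hypothetical; the key scheme is taken from the write test above):

    import lmdb
    import msgpack

    env = lmdb.open('/tmp/example-datastore')        # hypothetical path
    with env.begin() as db_tx:
        raw = db_tx.get(b'TestRecord:test:testing')  # key scheme from the test above
        if raw is not None:
            print(msgpack.unpackb(raw))              # -> b'good write'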