diff --git a/nucypher/datastore/base.py b/nucypher/datastore/base.py
new file mode 100644
index 000000000..0aa3d376d
--- /dev/null
+++ b/nucypher/datastore/base.py
@@ -0,0 +1,141 @@
+"""
+This file is part of nucypher.
+
+nucypher is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+nucypher is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public License
+along with nucypher. If not, see <https://www.gnu.org/licenses/>.
+"""
+import msgpack
+from typing import Any, Callable, Iterable, NamedTuple, Optional, Union
+
+
+class DBWriteError(Exception):
+    """
+    Exception class for when db writes fail.
+    """
+    pass
+
+
+class RecordField(NamedTuple):
+    field_type: Any
+    encode: Optional[Callable] = lambda field: field
+    decode: Optional[Callable] = lambda field: field
+
+
+class DatastoreRecord:
+    def __new__(cls, *args, **kwargs):
+        # Set default class attributes for the new instance
+        cls.__writeable = None
+        cls.__storagekey = f'{cls.__name__}:{{record_field}}:{{record_id}}'
+        return super().__new__(cls)
+
+    def __init__(self,
+                 db_tx: 'lmdb.Transaction',
+                 record_id: Union[int, str],
+                 writeable: bool = False) -> None:
+        self._record_id = record_id
+        self.__db_tx = db_tx
+        self.__writeable = writeable
+
+    def __setattr__(self, attr: str, value: Any) -> None:
+        """
+        This method is called when setting attributes on the class. We override
+        this method to serialize the value being set to the attribute, and then
+        we _write_ it to the database.
+
+        When `__writeable` is `None`, we only set attributes on the instance.
+        When `__writeable` is `False`, we raise a `TypeError`.
+
+        Finally, when `__writeable` is `True`, we get the `RecordField` for
+        the corresponding `attr` and check that the `value` being set is
+        the correct type via its `RecordField.field_type`. If the type is not
+        correct, we raise a `TypeError`.
+
+        If the type is correct, we then serialize it to bytes via its
+        `RecordField.encode` function and pack it with msgpack. Then the value
+        gets written to the database. If the value is unable to be written,
+        this will raise a `DBWriteError`.
+        """
+        # When writeable is None (meaning, it hasn't been __init__ yet), then
+        # we allow any attribute to be set on the instance.
+        if self.__writeable is None:
+            super().__setattr__(attr, value)
+
+        # Datastore records are not writeable/mutable by default, so we
+        # raise a TypeError in the event that writeable is False.
+        elif self.__writeable is False:
+            raise TypeError("This datastore record isn't writeable.")
+
+        # A datastore record is only mutated iff writeable is True.
+        elif self.__writeable is True:
+            record_field = self.__get_record_field(attr)
+            if not type(value) == record_field.field_type:
+                raise TypeError(f'Given record is type {type(value)}; expected {record_field.field_type}')
+            field_value = msgpack.packb(record_field.encode(value))
+            self.__write_raw_record(attr, field_value)
+
+    def __getattr__(self, attr: str) -> Any:
+        """
+        This method is called when accessing attributes that don't exist on the
+        class. We override this method to _read_ from the database and return
+        a deserialized record.
+
+        We deserialize records by calling the record's respective `RecordField.decode`
+        function. If the deserialized type doesn't match the type defined by
+        its `RecordField.field_type`, then this method will raise a `TypeError`.
+        """
+        # Handle __getattr__ look ups for private fields
+        if attr.startswith('_'):
+            return super().__getattr__(attr)
+
+        # Get the corresponding RecordField and retrieve the raw value from
+        # the db, unpack it, then use the `RecordField` to deserialize it.
+        record_field = self.__get_record_field(attr)
+        field_value = record_field.decode(msgpack.unpackb(self.__retrieve_raw_record(attr)))
+        if not type(field_value) == record_field.field_type:
+            raise TypeError(f"Decoded record was type {type(field_value)}; expected {record_field.field_type}")
+        return field_value
+
+    def __retrieve_raw_record(self, record_field: str) -> bytes:
+        """
+        Retrieves a raw record, as bytes, from the database given a `record_field`.
+        If the record doesn't exist, this method raises an `AttributeError`.
+        """
+        key = self.__storagekey.format(record_field=record_field, record_id=self._record_id).encode()
+        field_value = self.__db_tx.get(key, default=None)
+        if field_value is None:
+            raise AttributeError(f"No {record_field} record found for ID: {self._record_id}.")
+        return field_value
+
+    def __write_raw_record(self, record_field: str, value: bytes) -> None:
+        """
+        Writes a raw record, as bytes, to the database given a `record_field`
+        and a `value`.
+        If the record is unable to be written, this method raises a `DBWriteError`.
+        """
+        key = self.__storagekey.format(record_field=record_field, record_id=self._record_id).encode()
+        if not self.__db_tx.put(key, value, overwrite=True):
+            raise DBWriteError("Couldn't write the record to the database.")
+
+    def __get_record_field(self, attr: str) -> 'RecordField':
+        """
+        Uses `getattr` to return the `RecordField` object for a given
+        attribute.
+        These objects are accessed via class attrs as `_{attribute}`. If the
+        `RecordField` doesn't exist for a given `attr`, then this method will
+        raise a `TypeError`.
+        """
+        try:
+            record_field = getattr(self, f'_{attr}')
+        except AttributeError:
+            raise TypeError(f'No valid RecordField found on {self} for {attr}.')
+        return record_field
diff --git a/nucypher/datastore/datastore.py b/nucypher/datastore/datastore.py
index f04a54855..4f9fae576 100644
--- a/nucypher/datastore/datastore.py
+++ b/nucypher/datastore/datastore.py
@@ -19,15 +19,14 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
 import maya
 from bytestring_splitter import BytestringSplitter
 from datetime import datetime
-from sqlalchemy.exc import OperationalError
-from sqlalchemy.orm import sessionmaker
 from typing import List
 from umbral.keys import UmbralPublicKey
 from umbral.kfrags import KFrag
 
 from nucypher.crypto.signing import Signature
 from nucypher.crypto.utils import fingerprint_from_key
-from nucypher.datastore.db.models import Key, PolicyArrangement, Workorder
+from nucypher.datastore.base import DatastoreRecord, RecordField
+from nucypher.datastore.models import PolicyArrangement, Workorder
 
 
 class NotFound(Exception):
@@ -64,57 +63,6 @@ class Datastore:
             session.rollback()
             raise
 
-    #
-    # Keys
-    #
-
-    def add_key(self,
-                key: UmbralPublicKey,
-                is_signing: bool = True,
-                session=None
-                ) -> Key:
-        """
-        :param key: Keypair object to store in the keystore.
-
-        :return: The newly added key object.
-        """
-        session = session or self._session_on_init_thread
-        fingerprint = fingerprint_from_key(key)
-        key_data = bytes(key)
-        new_key = Key(fingerprint, key_data, is_signing)
-
-        session.add(new_key)
-        self.__commit(session=session)
-        return new_key
-
-    def get_key(self, fingerprint: bytes, session=None) -> UmbralPublicKey:
-        """
-        Returns a key from the Datastore.
-
-        :param fingerprint: Fingerprint, in bytes, of key to return
-
-        :return: Keypair of the returned key.
-        """
-        session = session or self._session_on_init_thread
-
-        key = session.query(Key).filter_by(fingerprint=fingerprint).first()
-        if not key:
-            raise NotFound("No key with fingerprint {} found.".format(fingerprint))
-
-        pubkey = UmbralPublicKey.from_bytes(key.key_data)
-        return pubkey
-
-    def del_key(self, fingerprint: bytes, session=None):
-        """
-        Deletes a key from the Datastore.
-
-        :param fingerprint: Fingerprint of key to delete
-        """
-        session = session or self._session_on_init_thread
-
-        session.query(Key).filter_by(fingerprint=fingerprint).delete()
-        self.__commit(session=session)
-
     #
     # Arrangements
     #
@@ -134,15 +82,11 @@ class Datastore:
         """
         session = session or self._session_on_init_thread
 
-        alice_key_instance = session.query(Key).filter_by(key_data=bytes(alice_verifying_key)).first()
-        if not alice_key_instance:
-            alice_key_instance = Key.from_umbral_key(alice_verifying_key, is_signing=True)
-
         new_policy_arrangement = PolicyArrangement(
             expiration=expiration,
             id=arrangement_id,
             kfrag=kfrag,
-            alice_verifying_key=alice_key_instance,
+            alice_verifying_key=bytes(alice_verifying_key),
             alice_signature=None,
             # bob_verifying_key.id  # TODO: Is this needed?
         )
@@ -180,7 +124,7 @@ class Datastore:
         if policy_arrangement is None:
             raise NotFound("Can't attach a kfrag to non-existent Arrangement {}".format(id_as_hex))
 
-        if policy_arrangement.alice_verifying_key.key_data != alice.stamp:
+        if policy_arrangement.alice_verifying_key != alice.stamp:
             raise alice.SuspiciousActivity
 
         policy_arrangement.kfrag = bytes(kfrag)
@@ -225,13 +169,7 @@ class Datastore:
         """
         session = session or self._session_on_init_thread
 
-        # Get or Create Bob Verifying Key
-        fingerprint = fingerprint_from_key(bob_verifying_key)
-        key = session.query(Key).filter_by(fingerprint=fingerprint).first()
-        if not key:
-            key = self.add_key(key=bob_verifying_key)
-
-        new_workorder = Workorder(bob_verifying_key_id=key.id,
+        new_workorder = Workorder(bob_verifying_key=bytes(bob_verifying_key),
                                   bob_signature=bob_signature,
                                   arrangement_id=arrangement_id)
 
@@ -254,16 +192,13 @@ class Datastore:
             workorders = query.all()  # Return all records
 
 
         else:
-            # Return arrangement records
             if arrangement_id:
                 workorders = query.filter_by(arrangement_id=arrangement_id)
 
             # Return records for Bob
             else:
-                fingerprint = fingerprint_from_key(bob_verifying_key)
-                key = session.query(Key).filter_by(fingerprint=fingerprint).first()
-                workorders = query.filter_by(bob_verifying_key_id=key.id)
+                workorders = query.filter_by(bob_verifying_key=bob_verifying_key)
 
         if not workorders:
             raise NotFound
diff --git a/nucypher/datastore/db/__init__.py b/nucypher/datastore/db/__init__.py
deleted file mode 100644
index b0ebc240d..000000000
--- a/nucypher/datastore/db/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-"""
-This file is part of nucypher.
-
-nucypher is free software: you can redistribute it and/or modify
-it under the terms of the GNU Affero General Public License as published by
-the Free Software Foundation, either version 3 of the License, or
-(at your option) any later version.
-
-nucypher is distributed in the hope that it will be useful,
-but WITHOUT ANY WARRANTY; without even the implied warranty of
-MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-GNU Affero General Public License for more details.
-
-You should have received a copy of the GNU Affero General Public License
-along with nucypher. If not, see <https://www.gnu.org/licenses/>.
-"""
-from sqlalchemy import event
-from sqlalchemy.engine import Engine
-from sqlalchemy.ext.declarative import declarative_base
-
-Base = declarative_base()
-
-
-@event.listens_for(Engine, "connect")
-def set_secure_delete_pragma(dbapi_connection, connection_record):
-    cursor = dbapi_connection.cursor()
-    cursor.execute("PRAGMA secure_delete=on")
-    cursor.close()
diff --git a/nucypher/datastore/db/models.py b/nucypher/datastore/db/models.py
deleted file mode 100644
index be67264b6..000000000
--- a/nucypher/datastore/db/models.py
+++ /dev/null
@@ -1,95 +0,0 @@
-"""
-This file is part of nucypher.
-
-nucypher is free software: you can redistribute it and/or modify
-it under the terms of the GNU Affero General Public License as published by
-the Free Software Foundation, either version 3 of the License, or
-(at your option) any later version.
-
-nucypher is distributed in the hope that it will be useful,
-but WITHOUT ANY WARRANTY; without even the implied warranty of
-MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-GNU Affero General Public License for more details.
-
-You should have received a copy of the GNU Affero General Public License
-along with nucypher. If not, see <https://www.gnu.org/licenses/>.
-"""
-from datetime import datetime
-from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, LargeBinary)
-from sqlalchemy.orm import relationship
-
-from nucypher.crypto.utils import fingerprint_from_key
-from nucypher.datastore.db import Base
-
-
-class Key(Base):
-    __tablename__ = 'keys'
-
-    id = Column(Integer, primary_key=True)
-    fingerprint = Column(LargeBinary, unique=True)
-    key_data = Column(LargeBinary, unique=True)
-    is_signing = Column(Boolean, unique=False)
-    created_at = Column(DateTime, default=datetime.utcnow)
-
-    def __init__(self, fingerprint, key_data, is_signing) -> None:
-        self.fingerprint = fingerprint
-        self.key_data = key_data
-        self.is_signing = is_signing
-
-    def __repr__(self):
-        return f'{self.__class__.__name__}(id={self.id})'
-
-    @classmethod
-    def from_umbral_key(cls, umbral_key, is_signing):
-        fingerprint = fingerprint_from_key(umbral_key)
-        key_data = bytes(umbral_key)
-        return cls(fingerprint, key_data, is_signing)
-
-
-class PolicyArrangement(Base):
-    __tablename__ = 'policyarrangements'
-
-    id = Column(LargeBinary, unique=True, primary_key=True)
-    expiration = Column(DateTime)
-    kfrag = Column(LargeBinary, unique=True, nullable=True)
-    alice_verifying_key_id = Column(Integer, ForeignKey('keys.id'))
-    alice_verifying_key = relationship(Key, backref="policies", lazy='joined')
-
-    # TODO: Maybe this will be two signatures - one for the offer, one for the KFrag.
-    alice_signature = Column(LargeBinary, unique=True, nullable=True)
-    created_at = Column(DateTime, default=datetime.utcnow)
-
-    def __init__(self,
-                 expiration,
-                 id,
-                 kfrag=None,
-                 alice_verifying_key=None,
-                 alice_signature=None
-                 ) -> None:
-
-        self.expiration = expiration
-        self.id = id
-        self.kfrag = kfrag
-        self.alice_verifying_key = alice_verifying_key
-        self.alice_signature = alice_signature
-
-    def __repr__(self):
-        return f'{self.__class__.__name__}(id={self.id})'
-
-
-class Workorder(Base):
-    __tablename__ = 'workorders'
-
-    id = Column(Integer, primary_key=True)
-    bob_verifying_key_id = Column(Integer, ForeignKey('keys.id'))
-    bob_signature = Column(LargeBinary, unique=True)
-    arrangement_id = Column(LargeBinary, unique=False)
-    created_at = Column(DateTime, default=datetime.utcnow)
-
-    def __init__(self, bob_verifying_key_id, bob_signature, arrangement_id) -> None:
-        self.bob_verifying_key_id = bob_verifying_key_id
-        self.bob_signature = bob_signature
-        self.arrangement_id = arrangement_id
-
-    def __repr__(self):
-        return f'{self.__class__.__name__}(id={self.id})'
diff --git a/nucypher/datastore/models.py b/nucypher/datastore/models.py
new file mode 100644
index 000000000..88982eccc
--- /dev/null
+++ b/nucypher/datastore/models.py
@@ -0,0 +1,45 @@
+"""
+This file is part of nucypher.
+
+nucypher is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+nucypher is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public License
+along with nucypher. If not, see <https://www.gnu.org/licenses/>.
+""" +from maya import MayaDT +from umbral.keys import UmbralPublicKey +from umbral.kfrags import KFrag + +from nucypher.crypto.signing import Signature +from nucypher.datastore.base import DatastoreRecord, RecordField + + +class PolicyArrangement(DatastoreRecord): + arrangement_id = RecordField(bytes) + expiration = RecordField(MayaDT, + encode=lambda maya_date: maya_date.iso8601().encode(), + decode=lambda maya_bytes: MayaDT.from_iso8601(maya_bytes.decode())) + kfrag = RecordField(KFrag, + encode=lambda kfrag: kfrag.to_bytes(), + decode=KFrag.from_bytes) + alice_verifying_key = RecordField(UmbralPublicKey, + encode=bytes, + decode=UmbralPublicKey.from_bytes) + + +class Workorder(DatastoreRecord): + arrangement_id = RecordField(bytes) + bob_verifying_key = RecordField(UmbralPublicKey, + encode=bytes, + decode=UmbralPublicKey.from_bytes) + bob_signature = RecordField(Signature, + encode=bytes, + decode=Signature.from_bytes) diff --git a/nucypher/network/server.py b/nucypher/network/server.py index 0bd55c829..641d9a9e6 100644 --- a/nucypher/network/server.py +++ b/nucypher/network/server.py @@ -319,7 +319,7 @@ def make_rest_app( policy_arrangement = datastore.get_policy_arrangement( id_as_hex.encode(), session=session) alice_pubkey = UmbralPublicKey.from_bytes( - policy_arrangement.alice_verifying_key.key_data) + policy_arrangement.alice_verifying_key) # Check that the request is the same for the provided revocation if id_as_hex != revocation.arrangement_id.hex(): @@ -355,8 +355,7 @@ def make_rest_app( # Get Work Order from nucypher.policy.collections import WorkOrder # Avoid circular import - alice_verifying_key_bytes = arrangement.alice_verifying_key.key_data - alice_verifying_key = UmbralPublicKey.from_bytes(alice_verifying_key_bytes) + alice_verifying_key = UmbralPublicKey.from_bytes(arrangement.alice_verifying_key) alice_address = canonical_address_from_umbral_key(alice_verifying_key) work_order_payload = request.data work_order = WorkOrder.from_rest_payload(arrangement_id=arrangement_id, diff --git a/tests/fixtures.py b/tests/fixtures.py index a3a69fc93..514e3d6f3 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -65,7 +65,6 @@ from nucypher.config.constants import TEMPORARY_DOMAIN from nucypher.crypto.powers import TransactingPower from nucypher.crypto.utils import canonical_address_from_umbral_key from nucypher.datastore import datastore -from nucypher.datastore.db import Base from nucypher.policy.collections import IndisputableEvidence, WorkOrder from nucypher.utilities.logging import GlobalLoggerSettings, Logger @@ -134,12 +133,12 @@ def temp_dir_path(): temp_dir.cleanup() -@pytest.fixture(scope="module") -def test_datastore(): - engine = create_engine('sqlite:///:memory:') - Base.metadata.create_all(engine) - test_datastore = datastore.Datastore(engine) - yield test_datastore +#@pytest.fixture(scope="module") +#def test_datastore(): +# engine = create_engine('sqlite:///:memory:') +# Base.metadata.create_all(engine) +# test_datastore = datastore.Datastore(engine) +# yield test_datastore @pytest.fixture(scope='function') diff --git a/tests/integration/datastore/test_datastore.py b/tests/integration/datastore/test_datastore.py index ec171b02c..91d5d89d7 100644 --- a/tests/integration/datastore/test_datastore.py +++ b/tests/integration/datastore/test_datastore.py @@ -14,64 +14,104 @@ GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with nucypher. If not, see . 
""" +import lmdb +import maya +import msgpack import pytest +import tempfile from datetime import datetime from nucypher.datastore import datastore, keypairs +from nucypher.datastore.base import DatastoreRecord, RecordField +from nucypher.datastore.models import PolicyArrangement, Workorder -@pytest.mark.usefixtures('testerchain') -def test_key_sqlite_datastore(test_datastore, federated_bob): +def test_datastore_record_read(): + class TestRecord(DatastoreRecord): + _test = RecordField(bytes) + _test_date = RecordField(datetime, + encode=lambda val: datetime.isoformat(val).encode(), + decode=lambda val: datetime.fromisoformat(val.decode())) - # Test add pubkey - test_datastore.add_key(federated_bob.stamp, is_signing=True) + db_env = lmdb.open(tempfile.mkdtemp()) + with db_env.begin() as db_tx: + # Check the default attrs. + test_rec = TestRecord(db_tx, 'testing', writeable=False) + assert test_rec._record_id == 'testing' + assert test_rec._DatastoreRecord__db_tx == db_tx + assert test_rec._DatastoreRecord__writeable == False + assert test_rec._DatastoreRecord__storagekey == 'TestRecord:{record_field}:{record_id}' - # Test get pubkey - query_key = test_datastore.get_key(federated_bob.stamp.fingerprint()) - assert bytes(federated_bob.stamp) == bytes(query_key) + # Reading an attr with no RecordField should error + with pytest.raises(TypeError): + should_error = test_rec.nonexistant_field - # Test del pubkey - test_datastore.del_key(federated_bob.stamp.fingerprint()) - with pytest.raises(datastore.NotFound): - del_key = test_datastore.get_key(federated_bob.stamp.fingerprint()) + # Reading when no records exist errors + with pytest.raises(AttributeError): + should_error = test_rec.test + + # The record is not writeable + with pytest.raises(TypeError): + test_rec.test = b'should error' -def test_policy_arrangement_sqlite_datastore(test_datastore): - alice_keypair_sig = keypairs.SigningKeypair(generate_keys_if_needed=True) +def test_datastore_record_write(): + class TestRecord(DatastoreRecord): + _test = RecordField(bytes) + _test_date = RecordField(datetime, + encode=lambda val: datetime.isoformat(val).encode(), + decode=lambda val: datetime.fromisoformat(val.decode())) - arrangement_id = b'test' + # Test writing + db_env = lmdb.open(tempfile.mkdtemp()) + with db_env.begin(write=True) as db_tx: + test_rec = TestRecord(db_tx, 'testing', writeable=True) + assert test_rec._DatastoreRecord__writeable == True - # Test add PolicyArrangement - new_arrangement = test_datastore.add_policy_arrangement( - datetime.utcnow(), b'test', arrangement_id, alice_verifying_key=alice_keypair_sig.pubkey, - alice_signature=b'test' - ) + # Write an invalid serialization of `test` and test retrieving it is + # a TypeError + db_tx.put(b'TestRecord:test:testing', msgpack.packb(1234)) + with pytest.raises(TypeError): + should_error = test_rec.test - # Test get PolicyArrangement - query_arrangement = test_datastore.get_policy_arrangement(arrangement_id) - assert new_arrangement == query_arrangement + # Writing an invalid serialization of a field is a `TypeError` + with pytest.raises(TypeError): + test_rec.test = 1234 - # Test del PolicyArrangement - test_datastore.del_policy_arrangement(arrangement_id) - with pytest.raises(datastore.NotFound): - del_key = test_datastore.get_policy_arrangement(arrangement_id) + # Test writing a valid field and getting it. 
+        test_rec.test = b'good write'
+        assert test_rec.test == b'good write'
+
+    # TODO: Mock a `DBWriteError`
 
 
-def test_workorder_sqlite_datastore(test_datastore):
-    bob_keypair_sig1 = keypairs.SigningKeypair(generate_keys_if_needed=True)
-    bob_keypair_sig2 = keypairs.SigningKeypair(generate_keys_if_needed=True)
-
-    arrangement_id = b'test'
-
-    # Test add workorder
-    new_workorder1 = test_datastore.save_workorder(bob_keypair_sig1.pubkey, b'test0', arrangement_id)
-    new_workorder2 = test_datastore.save_workorder(bob_keypair_sig2.pubkey, b'test1', arrangement_id)
-
-    # Test get workorder
-    query_workorders = test_datastore.get_workorders(arrangement_id)
-    assert {new_workorder1, new_workorder2}.issubset(query_workorders)
-
-    # Test del workorder
-    deleted = test_datastore.del_workorders(arrangement_id)
-    assert deleted > 0
-    assert len(test_datastore.get_workorders(arrangement_id)) == 0
+# def test_datastore_policy_arrangement_model():
+#     arrangement_id = b'test'
+#     expiration = maya.now()
+#     alice_verifying_key = keypairs.SigningKeypair(generate_keys_if_needed=True).pubkey
+#
+#     # TODO: Leaving out KFrag for now since I don't have an easy way to grab one.
+#     test_record = PolicyArrangement(arrangement_id=arrangement_id,
+#                                     expiration=expiration,
+#                                     alice_verifying_key=alice_verifying_key)
+#
+#     assert test_record.arrangement_id == arrangement_id
+#     assert test_record.expiration == expiration
+#     assert test_record.alice_verifying_key == alice_verifying_key
+#     assert test_record == PolicyArrangement.from_bytes(test_record.to_bytes())
+#
+#
+# def test_datastore_workorder_model():
+#     bob_keypair = keypairs.SigningKeypair(generate_keys_if_needed=True)
+#
+#     arrangement_id = b'test'
+#     bob_verifying_key = bob_keypair.pubkey
+#     bob_signature = bob_keypair.sign(b'test')
+#
+#     test_record = Workorder(arrangement_id=arrangement_id,
+#                             bob_verifying_key=bob_verifying_key,
+#                             bob_signature=bob_signature)
+#
+#     assert test_record.arrangement_id == arrangement_id
+#     assert test_record.bob_verifying_key == bob_verifying_key
+#     assert test_record.bob_signature == bob_signature
+#     assert test_record == Workorder.from_bytes(test_record.to_bytes())
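
The sketch below is not part of the patch itself; it illustrates how the new lmdb-backed DatastoreRecord/RecordField machinery in nucypher/datastore/base.py is intended to be used, mirroring the tests in this diff. The ExampleRecord class, its _message field, and the temporary directory are illustrative assumptions only.

# Minimal usage sketch (not from the diff); ExampleRecord and its field are assumptions.
import lmdb
import tempfile

from nucypher.datastore.base import DatastoreRecord, RecordField


class ExampleRecord(DatastoreRecord):
    # Fields are resolved by DatastoreRecord via the `_{attribute}` naming convention.
    _message = RecordField(bytes)


db_env = lmdb.open(tempfile.mkdtemp())

# Writing requires a write transaction *and* `writeable=True`.
with db_env.begin(write=True) as db_tx:
    record = ExampleRecord(db_tx, record_id='example', writeable=True)
    # The value is type-checked against the RecordField, packed with msgpack,
    # and stored under the key b'ExampleRecord:message:example'.
    record.message = b'hello datastore'

# Reading works on a read-only transaction; records are not writeable by default.
with db_env.begin() as db_tx:
    record = ExampleRecord(db_tx, record_id='example')
    assert record.message == b'hello datastore'

Each field lands under its own key of the form {ClassName}:{field}:{record_id}, so a single logical record spans several lmdb keys and is always read or written inside an explicit transaction.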
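
Along the same lines, a sketch (again, not from the patch) of how one of the new model records round-trips a typed field through its RecordField encode/decode pair. It assumes the underscore-prefixed field names in nucypher/datastore/models.py and uses an arbitrary record id and timestamp.

# Assumed example only: PolicyArrangement stores `expiration` (a MayaDT) as
# ISO 8601 bytes via its RecordField encode/decode pair.
import lmdb
import tempfile
from maya import MayaDT

from nucypher.datastore.models import PolicyArrangement

db_env = lmdb.open(tempfile.mkdtemp())
expiration = MayaDT.from_iso8601('2021-01-01T00:00:00.0Z')  # arbitrary fixed timestamp

with db_env.begin(write=True) as db_tx:
    arrangement = PolicyArrangement(db_tx, record_id='test-arrangement', writeable=True)
    arrangement.arrangement_id = b'test-arrangement'
    arrangement.expiration = expiration  # MayaDT -> iso8601() bytes -> msgpack

with db_env.begin() as db_tx:
    arrangement = PolicyArrangement(db_tx, record_id='test-arrangement')
    assert arrangement.arrangement_id == b'test-arrangement'
    assert arrangement.expiration == expiration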