Create a base 'DatastoreRecord' class and models with tests

Move the base models to base.py and add more datastore tests
Change key structure to 'RecordType:RecordField:RecordID'
Eliminate the db module
pull/2099/head
tuxxy 2020-06-17 15:07:29 +02:00
parent 4026de543e
commit 184b3bf563
8 changed files with 284 additions and 248 deletions

141
nucypher/datastore/base.py Normal file
View File

@ -0,0 +1,141 @@
"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import msgpack
from typing import Any, Callable, Iterable, NamedTuple, Optional, Union
class DBWriteError(Exception):
"""
Exception class for when db writes fail.
"""
pass
class RecordField(NamedTuple):
field_type: Any
encode: Optional[Callable] = lambda field: field
decode: Optional[Callable] = lambda field: field
class DatastoreRecord:
def __new__(cls, *args, **kwargs):
# Set default class attributes for the new instance
cls.__writeable = None
cls.__storagekey = f'{cls.__name__}:{{record_field}}:{{record_id}}'
return super().__new__(cls)
def __init__(self,
db_tx: 'lmdb.Transaction',
record_id: Union[int, str],
writeable: bool = False) -> None:
self._record_id = record_id
self.__db_tx = db_tx
self.__writeable = writeable
def __setattr__(self, attr: str, value: Any) -> None:
"""
This method is called when setting attributes on the class. We override
this method to serialize the value being set to the attribute, and then
we _write_ it to the database.
When `__writeable` is `None`, we only set attributes on the instance.
When `__writeable` is `False`, we raise a `TypeError`.
Finally, when `__writeable` is `True`, we get the `RecordField` for
the corresponding `attr` and check that the `value` being set is
the correct type via its `RecordField.field_type`. If the type is not
correct, we raise a `TypeError`.
If the type is correct, we then serialize it to bytes via its
`RecordField.encode` function and pack it with msgpack. Then the value
gets written to the database. If the value is unable to be written,
this will raise a `DBWriteError`.
"""
# When writeable is None (meaning, it hasn't been __init__ yet), then
# we allow any attribute to be set on the instance.
if self.__writeable is None:
super().__setattr__(attr, value)
# Datastore records are not writeable/mutable by default, so we
# raise a TypeError in the event that writeable is False.
elif self.__writeable is False:
raise TypeError("This datastore record isn't writeable.")
# A datastore record is only mutated iff writeable is True.
elif self.__writeable is True:
record_field = self.__get_record_field(attr)
if not type(value) == record_field.field_type:
raise TypeError(f'Given record is type {type(value)}; expected {record_field.field_type}')
field_value = msgpack.packb(record_field.encode(value))
self.__write_raw_record(attr, field_value)
def __getattr__(self, attr: str) -> Any:
"""
This method is called when accessing attributes that don't exist on the
class. We override this method to _read_ from the database and return
a deserialized record.
We deserialize records by calling the record's respective `RecordField.decode`
function. If the deserialized type doesn't match the type defined by
its `RecordField.field_type`, then this method will raise a `TypeError`.
"""
# Handle __getattr__ look ups for private fields
if attr.startswith('_'):
return super().__getattr__(attr)
# Get the corresponding RecordField and retrieve the raw value from
# the db, unpack it, then use the `RecordField` to deserialize it.
record_field = self.__get_record_field(attr)
field_value = record_field.decode(msgpack.unpackb(self.__retrieve_raw_record(attr)))
if not type(field_value) == record_field.field_type:
raise TypeError(f"Decoded record was type {type(field_value)}; expected {record_field.field_type}")
return field_value
def __retrieve_raw_record(self, record_field: str) -> bytes:
"""
Retrieves a raw record, as bytes, from the database given a `record_field`.
If the record doesn't exist, this method raises an `AttributeError`.
"""
key = self.__storagekey.format(record_field=record_field, record_id=self._record_id).encode()
field_value = self.__db_tx.get(key, default=None)
if field_value is None:
raise AttributeError(f"No {record_field} record found for ID: {self._record_id}.")
return field_value
def __write_raw_record(self, record_field: str, value: bytes) -> None:
"""
Writes a raw record, as bytes, to the database given a `record_field`
and a `value`.
If the record is unable to be written, this method raises a `DBWriteError`.
"""
key = self.__storagekey.format(record_field=record_field, record_id=self._record_id).encode()
if not self.__db_tx.put(key, value, overwrite=True):
raise DBWriteError("Couldn't write the record to the database.")
def __get_record_field(self, attr: str) -> 'RecordField':
"""
Uses `getattr` to return the `RecordField` object for a given
attribute.
These objects are accessed via class attrs as `_{attribute}`. If the
`RecordField` doesn't exist for a given `attr`, then this method will
raise a `TypeError`.
"""
try:
record_field = getattr(self, f'_{attr}')
except AttributeError:
raise TypeError(f'No valid RecordField found on {self} for {attr}.')
return record_field

View File

@ -19,15 +19,14 @@ along with nucypher. If not, see <https://www.gnu.org/licenses/>.
import maya
from bytestring_splitter import BytestringSplitter
from datetime import datetime
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from typing import List
from umbral.keys import UmbralPublicKey
from umbral.kfrags import KFrag
from nucypher.crypto.signing import Signature
from nucypher.crypto.utils import fingerprint_from_key
from nucypher.datastore.db.models import Key, PolicyArrangement, Workorder
from nucypher.datastore.db.base import DatastoreRecord, RecordField
from nucypher.datastore.db.models import PolicyArrangement, Workorder
class NotFound(Exception):
@ -64,57 +63,6 @@ class Datastore:
session.rollback()
raise
#
# Keys
#
def add_key(self,
key: UmbralPublicKey,
is_signing: bool = True,
session=None
) -> Key:
"""
:param key: Keypair object to store in the keystore.
:return: The newly added key object.
"""
session = session or self._session_on_init_thread
fingerprint = fingerprint_from_key(key)
key_data = bytes(key)
new_key = Key(fingerprint, key_data, is_signing)
session.add(new_key)
self.__commit(session=session)
return new_key
def get_key(self, fingerprint: bytes, session=None) -> UmbralPublicKey:
"""
Returns a key from the Datastore.
:param fingerprint: Fingerprint, in bytes, of key to return
:return: Keypair of the returned key.
"""
session = session or self._session_on_init_thread
key = session.query(Key).filter_by(fingerprint=fingerprint).first()
if not key:
raise NotFound("No key with fingerprint {} found.".format(fingerprint))
pubkey = UmbralPublicKey.from_bytes(key.key_data)
return pubkey
def del_key(self, fingerprint: bytes, session=None):
"""
Deletes a key from the Datastore.
:param fingerprint: Fingerprint of key to delete
"""
session = session or self._session_on_init_thread
session.query(Key).filter_by(fingerprint=fingerprint).delete()
self.__commit(session=session)
#
# Arrangements
#
@ -134,15 +82,11 @@ class Datastore:
"""
session = session or self._session_on_init_thread
alice_key_instance = session.query(Key).filter_by(key_data=bytes(alice_verifying_key)).first()
if not alice_key_instance:
alice_key_instance = Key.from_umbral_key(alice_verifying_key, is_signing=True)
new_policy_arrangement = PolicyArrangement(
expiration=expiration,
id=arrangement_id,
kfrag=kfrag,
alice_verifying_key=alice_key_instance,
alice_verifying_key=bytes(alice_verifying_key),
alice_signature=None,
# bob_verifying_key.id # TODO: Is this needed?
)
@ -180,7 +124,7 @@ class Datastore:
if policy_arrangement is None:
raise NotFound("Can't attach a kfrag to non-existent Arrangement {}".format(id_as_hex))
if policy_arrangement.alice_verifying_key.key_data != alice.stamp:
if policy_arrangement.alice_verifying_key != alice.stamp:
raise alice.SuspiciousActivity
policy_arrangement.kfrag = bytes(kfrag)
@ -225,13 +169,7 @@ class Datastore:
"""
session = session or self._session_on_init_thread
# Get or Create Bob Verifying Key
fingerprint = fingerprint_from_key(bob_verifying_key)
key = session.query(Key).filter_by(fingerprint=fingerprint).first()
if not key:
key = self.add_key(key=bob_verifying_key)
new_workorder = Workorder(bob_verifying_key_id=key.id,
new_workorder = Workorder(bob_verifying_key=bytes(bob_verifying_key),
bob_signature=bob_signature,
arrangement_id=arrangement_id)
@ -254,16 +192,13 @@ class Datastore:
workorders = query.all() # Return all records
else:
# Return arrangement records
if arrangement_id:
workorders = query.filter_by(arrangement_id=arrangement_id)
# Return records for Bob
else:
fingerprint = fingerprint_from_key(bob_verifying_key)
key = session.query(Key).filter_by(fingerprint=fingerprint).first()
workorders = query.filter_by(bob_verifying_key_id=key.id)
workorders = query.filter_by(bob_verifying_key=bob_verifying_key)
if not workorders:
raise NotFound

View File

@ -1,28 +0,0 @@
"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
@event.listens_for(Engine, "connect")
def set_secure_delete_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA secure_delete=on")
cursor.close()

View File

@ -1,95 +0,0 @@
"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from datetime import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, LargeBinary)
from sqlalchemy.orm import relationship
from nucypher.crypto.utils import fingerprint_from_key
from nucypher.datastore.db import Base
class Key(Base):
__tablename__ = 'keys'
id = Column(Integer, primary_key=True)
fingerprint = Column(LargeBinary, unique=True)
key_data = Column(LargeBinary, unique=True)
is_signing = Column(Boolean, unique=False)
created_at = Column(DateTime, default=datetime.utcnow)
def __init__(self, fingerprint, key_data, is_signing) -> None:
self.fingerprint = fingerprint
self.key_data = key_data
self.is_signing = is_signing
def __repr__(self):
return f'{self.__class__.__name__}(id={self.id})'
@classmethod
def from_umbral_key(cls, umbral_key, is_signing):
fingerprint = fingerprint_from_key(umbral_key)
key_data = bytes(umbral_key)
return cls(fingerprint, key_data, is_signing)
class PolicyArrangement(Base):
__tablename__ = 'policyarrangements'
id = Column(LargeBinary, unique=True, primary_key=True)
expiration = Column(DateTime)
kfrag = Column(LargeBinary, unique=True, nullable=True)
alice_verifying_key_id = Column(Integer, ForeignKey('keys.id'))
alice_verifying_key = relationship(Key, backref="policies", lazy='joined')
# TODO: Maybe this will be two signatures - one for the offer, one for the KFrag.
alice_signature = Column(LargeBinary, unique=True, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
def __init__(self,
expiration,
id,
kfrag=None,
alice_verifying_key=None,
alice_signature=None
) -> None:
self.expiration = expiration
self.id = id
self.kfrag = kfrag
self.alice_verifying_key = alice_verifying_key
self.alice_signature = alice_signature
def __repr__(self):
return f'{self.__class__.__name__}(id={self.id})'
class Workorder(Base):
__tablename__ = 'workorders'
id = Column(Integer, primary_key=True)
bob_verifying_key_id = Column(Integer, ForeignKey('keys.id'))
bob_signature = Column(LargeBinary, unique=True)
arrangement_id = Column(LargeBinary, unique=False)
created_at = Column(DateTime, default=datetime.utcnow)
def __init__(self, bob_verifying_key_id, bob_signature, arrangement_id) -> None:
self.bob_verifying_key_id = bob_verifying_key_id
self.bob_signature = bob_signature
self.arrangement_id = arrangement_id
def __repr__(self):
return f'{self.__class__.__name__}(id={self.id})'

View File

@ -0,0 +1,45 @@
"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
from maya import MayaDT
from umbral.keys import UmbralPublicKey
from umbral.kfrags import KFrag
from nucypher.crypto.signing import Signature
from nucypher.datastore.base import DatastoreRecord, RecordField
class PolicyArrangement(DatastoreRecord):
arrangement_id = RecordField(bytes)
expiration = RecordField(MayaDT,
encode=lambda maya_date: maya_date.iso8601().encode(),
decode=lambda maya_bytes: MayaDT.from_iso8601(maya_bytes.decode()))
kfrag = RecordField(KFrag,
encode=lambda kfrag: kfrag.to_bytes(),
decode=KFrag.from_bytes)
alice_verifying_key = RecordField(UmbralPublicKey,
encode=bytes,
decode=UmbralPublicKey.from_bytes)
class Workorder(DatastoreRecord):
arrangement_id = RecordField(bytes)
bob_verifying_key = RecordField(UmbralPublicKey,
encode=bytes,
decode=UmbralPublicKey.from_bytes)
bob_signature = RecordField(Signature,
encode=bytes,
decode=Signature.from_bytes)

View File

@ -319,7 +319,7 @@ def make_rest_app(
policy_arrangement = datastore.get_policy_arrangement(
id_as_hex.encode(), session=session)
alice_pubkey = UmbralPublicKey.from_bytes(
policy_arrangement.alice_verifying_key.key_data)
policy_arrangement.alice_verifying_key)
# Check that the request is the same for the provided revocation
if id_as_hex != revocation.arrangement_id.hex():
@ -355,8 +355,7 @@ def make_rest_app(
# Get Work Order
from nucypher.policy.collections import WorkOrder # Avoid circular import
alice_verifying_key_bytes = arrangement.alice_verifying_key.key_data
alice_verifying_key = UmbralPublicKey.from_bytes(alice_verifying_key_bytes)
alice_verifying_key = UmbralPublicKey.from_bytes(arrangement.alice_verifying_key)
alice_address = canonical_address_from_umbral_key(alice_verifying_key)
work_order_payload = request.data
work_order = WorkOrder.from_rest_payload(arrangement_id=arrangement_id,

View File

@ -65,7 +65,6 @@ from nucypher.config.constants import TEMPORARY_DOMAIN
from nucypher.crypto.powers import TransactingPower
from nucypher.crypto.utils import canonical_address_from_umbral_key
from nucypher.datastore import datastore
from nucypher.datastore.db import Base
from nucypher.policy.collections import IndisputableEvidence, WorkOrder
from nucypher.utilities.logging import GlobalLoggerSettings, Logger
@ -134,12 +133,12 @@ def temp_dir_path():
temp_dir.cleanup()
@pytest.fixture(scope="module")
def test_datastore():
engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
test_datastore = datastore.Datastore(engine)
yield test_datastore
#@pytest.fixture(scope="module")
#def test_datastore():
# engine = create_engine('sqlite:///:memory:')
# Base.metadata.create_all(engine)
# test_datastore = datastore.Datastore(engine)
# yield test_datastore
@pytest.fixture(scope='function')

View File

@ -14,64 +14,104 @@ GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import lmdb
import maya
import msgpack
import pytest
import tempfile
from datetime import datetime
from nucypher.datastore import datastore, keypairs
from nucypher.datastore.base import DatastoreRecord, RecordField
from nucypher.datastore.models import PolicyArrangement, Workorder
@pytest.mark.usefixtures('testerchain')
def test_key_sqlite_datastore(test_datastore, federated_bob):
def test_datastore_record_read():
class TestRecord(DatastoreRecord):
_test = RecordField(bytes)
_test_date = RecordField(datetime,
encode=lambda val: datetime.isoformat(val).encode(),
decode=lambda val: datetime.fromisoformat(val.decode()))
# Test add pubkey
test_datastore.add_key(federated_bob.stamp, is_signing=True)
db_env = lmdb.open(tempfile.mkdtemp())
with db_env.begin() as db_tx:
# Check the default attrs.
test_rec = TestRecord(db_tx, 'testing', writeable=False)
assert test_rec._record_id == 'testing'
assert test_rec._DatastoreRecord__db_tx == db_tx
assert test_rec._DatastoreRecord__writeable == False
assert test_rec._DatastoreRecord__storagekey == 'TestRecord:{record_field}:{record_id}'
# Test get pubkey
query_key = test_datastore.get_key(federated_bob.stamp.fingerprint())
assert bytes(federated_bob.stamp) == bytes(query_key)
# Reading an attr with no RecordField should error
with pytest.raises(TypeError):
should_error = test_rec.nonexistant_field
# Test del pubkey
test_datastore.del_key(federated_bob.stamp.fingerprint())
with pytest.raises(datastore.NotFound):
del_key = test_datastore.get_key(federated_bob.stamp.fingerprint())
# Reading when no records exist errors
with pytest.raises(AttributeError):
should_error = test_rec.test
# The record is not writeable
with pytest.raises(TypeError):
test_rec.test = b'should error'
def test_policy_arrangement_sqlite_datastore(test_datastore):
alice_keypair_sig = keypairs.SigningKeypair(generate_keys_if_needed=True)
def test_datastore_record_write():
class TestRecord(DatastoreRecord):
_test = RecordField(bytes)
_test_date = RecordField(datetime,
encode=lambda val: datetime.isoformat(val).encode(),
decode=lambda val: datetime.fromisoformat(val.decode()))
arrangement_id = b'test'
# Test writing
db_env = lmdb.open(tempfile.mkdtemp())
with db_env.begin(write=True) as db_tx:
test_rec = TestRecord(db_tx, 'testing', writeable=True)
assert test_rec._DatastoreRecord__writeable == True
# Test add PolicyArrangement
new_arrangement = test_datastore.add_policy_arrangement(
datetime.utcnow(), b'test', arrangement_id, alice_verifying_key=alice_keypair_sig.pubkey,
alice_signature=b'test'
)
# Write an invalid serialization of `test` and test retrieving it is
# a TypeError
db_tx.put(b'TestRecord:test:testing', msgpack.packb(1234))
with pytest.raises(TypeError):
should_error = test_rec.test
# Test get PolicyArrangement
query_arrangement = test_datastore.get_policy_arrangement(arrangement_id)
assert new_arrangement == query_arrangement
# Writing an invalid serialization of a field is a `TypeError`
with pytest.raises(TypeError):
test_rec.test = 1234
# Test del PolicyArrangement
test_datastore.del_policy_arrangement(arrangement_id)
with pytest.raises(datastore.NotFound):
del_key = test_datastore.get_policy_arrangement(arrangement_id)
# Test writing a valid field and getting it.
test_rec.test = b'good write'
assert test_rec.test == b'good write'
# TODO: Mock a `DBWriteError`
def test_workorder_sqlite_datastore(test_datastore):
bob_keypair_sig1 = keypairs.SigningKeypair(generate_keys_if_needed=True)
bob_keypair_sig2 = keypairs.SigningKeypair(generate_keys_if_needed=True)
arrangement_id = b'test'
# Test add workorder
new_workorder1 = test_datastore.save_workorder(bob_keypair_sig1.pubkey, b'test0', arrangement_id)
new_workorder2 = test_datastore.save_workorder(bob_keypair_sig2.pubkey, b'test1', arrangement_id)
# Test get workorder
query_workorders = test_datastore.get_workorders(arrangement_id)
assert {new_workorder1, new_workorder2}.issubset(query_workorders)
# Test del workorder
deleted = test_datastore.del_workorders(arrangement_id)
assert deleted > 0
assert len(test_datastore.get_workorders(arrangement_id)) == 0
# def test_datastore_policy_arrangement_model():
# arrangement_id = b'test'
# expiration = maya.now()
# alice_verifying_key = keypairs.SigningKeypair(generate_keys_if_needed=True).pubkey
#
# # TODO: Leaving out KFrag for now since I don't have an easy way to grab one.
# test_record = PolicyArrangement(arrangement_id=arrangement_id,
# expiration=expiration,
# alice_verifying_key=alice_verifying_key)
#
# assert test_record.arrangement_id == arrangement_id
# assert test_record.expiration == expiration
# assert alice_verifying_key == alice_verifying_key
# assert test_record == PolicyArrangement.from_bytes(test_record.to_bytes())
#
#
# def test_datastore_workorder_model():
# bob_keypair = keypairs.SigningKeypair(generate_keys_if_needed=True)
#
# arrangement_id = b'test'
# bob_verifying_key = bob_keypair.pubkey
# bob_signature = bob_keypair.sign(b'test')
#
# test_record = Workorder(arrangement_id=arrangement_id,
# bob_verifying_key=bob_verifying_key,
# bob_signature=bob_signature)
#
# assert test_record.arrangement_id == arrangement_id
# assert test_record.bob_verifying_key == bob_verifying_key
# assert test_record.bob_signature == bob_signature
# assert test_record == Workorder.from_bytes(test_record.to_bytes())