Allow DKGStorage to return Option object/value or None.

pull/3183/head
derekpierre 2023-07-17 14:10:53 -04:00
parent 3d81a4e132
commit d944c0bbe6
1 changed files with 14 additions and 10 deletions

View File

@ -1,5 +1,5 @@
from collections import defaultdict from collections import defaultdict
from typing import Union from typing import Optional, Union
from hexbytes import HexBytes from hexbytes import HexBytes
from nucypher_core.ferveo import AggregatedTranscript, Transcript from nucypher_core.ferveo import AggregatedTranscript, Transcript
@ -15,7 +15,7 @@ class DKGStorage:
def store_transcript(self, ritual_id: int, transcript: Transcript) -> None: def store_transcript(self, ritual_id: int, transcript: Transcript) -> None:
self.data["transcripts"][ritual_id] = bytes(transcript) self.data["transcripts"][ritual_id] = bytes(transcript)
def get_transcript(self, ritual_id: int) -> Transcript: def get_transcript(self, ritual_id: int) -> Optional[Transcript]:
data = self.data["transcripts"][ritual_id] data = self.data["transcripts"][ritual_id]
transcript = Transcript.from_bytes(data) transcript = Transcript.from_bytes(data)
return transcript return transcript
@ -25,14 +25,18 @@ class DKGStorage:
) -> None: ) -> None:
self.data["transcript_receipts"][ritual_id] = txhash_or_receipt self.data["transcript_receipts"][ritual_id] = txhash_or_receipt
def get_transcript_receipt(self, ritual_id: int) -> Union[TxReceipt, HexBytes]: def get_transcript_receipt(
return self.data["transcript_receipts"][ritual_id] self, ritual_id: int
) -> Optional[Union[TxReceipt, HexBytes]]:
return self.data["transcript_receipts"].get(ritual_id)
def store_aggregated_transcript(self, ritual_id: int, aggregated_transcript: AggregatedTranscript) -> None: def store_aggregated_transcript(self, ritual_id: int, aggregated_transcript: AggregatedTranscript) -> None:
self.data["aggregated_transcripts"][ritual_id] = bytes(aggregated_transcript) self.data["aggregated_transcripts"][ritual_id] = bytes(aggregated_transcript)
def get_aggregated_transcript(self, ritual_id: int) -> AggregatedTranscript: def get_aggregated_transcript(
return self.data["aggregated_transcripts"][ritual_id] self, ritual_id: int
) -> Optional[AggregatedTranscript]:
return self.data["aggregated_transcripts"].get(ritual_id)
def store_aggregated_transcript_receipt( def store_aggregated_transcript_receipt(
self, ritual_id: int, txhash_or_receipt: Union[TxReceipt, HexBytes] self, ritual_id: int, txhash_or_receipt: Union[TxReceipt, HexBytes]
@ -41,11 +45,11 @@ class DKGStorage:
def get_aggregated_transcript_receipt( def get_aggregated_transcript_receipt(
self, ritual_id: int self, ritual_id: int
) -> Union[TxReceipt, HexBytes]: ) -> Optional[Union[TxReceipt, HexBytes]]:
return self.data["aggregated_transcript_receipts"][ritual_id] return self.data["aggregated_transcript_receipts"].get(ritual_id)
def store_public_key(self, ritual_id: int, public_key: bytes) -> None: def store_public_key(self, ritual_id: int, public_key: bytes) -> None:
self.data["public_keys"][ritual_id] = public_key self.data["public_keys"][ritual_id] = public_key
def get_public_key(self, ritual_id: int) -> bytes: def get_public_key(self, ritual_id: int) -> Optional[bytes]:
return self.data["public_keys"][ritual_id] return self.data["public_keys"].get(ritual_id)