From bc77489e942af97ff463c9c85a0db3add4dad1b7 Mon Sep 17 00:00:00 2001 From: Kieran Prasch Date: Fri, 10 Feb 2023 12:11:24 -0800 Subject: [PATCH] CoordinatorV3 adaptation and integration with ferveo-server crate draft. --- nucypher/blockchain/eth/actors.py | 171 ++++++++++++++++++------ nucypher/blockchain/eth/agents.py | 133 +++++++++++------- nucypher/blockchain/eth/trackers/dkg.py | 29 ++-- nucypher/crypto/dkg.py | 68 ++++++++++ tests/mock/ferveo.py | 34 ----- 5 files changed, 299 insertions(+), 136 deletions(-) create mode 100644 nucypher/crypto/dkg.py delete mode 100644 tests/mock/ferveo.py diff --git a/nucypher/blockchain/eth/actors.py b/nucypher/blockchain/eth/actors.py index e455d5602..b1e1cecea 100644 --- a/nucypher/blockchain/eth/actors.py +++ b/nucypher/blockchain/eth/actors.py @@ -1,9 +1,9 @@ import json +import time from decimal import Decimal -from typing import Optional, Tuple, Union, Dict, List +from typing import Optional, Tuple, Union import maya -import time from constant_sorrow.constants import FULL from eth_typing import ChecksumAddress from hexbytes import HexBytes @@ -15,8 +15,9 @@ from nucypher.blockchain.economics import Economics from nucypher.blockchain.eth.agents import ( AdjudicatorAgent, ContractAgency, + CoordinatorAgent, NucypherTokenAgent, - PREApplicationAgent, CoordinatorAgent, + PREApplicationAgent, ) from nucypher.blockchain.eth.constants import NULL_ADDRESS from nucypher.blockchain.eth.decorators import save_receipt, validate_checksum_address @@ -34,12 +35,13 @@ from nucypher.blockchain.eth.token import NU from nucypher.blockchain.eth.trackers.dkg import RitualTracker from nucypher.blockchain.eth.trackers.pre import WorkTracker from nucypher.config.constants import DEFAULT_CONFIG_ROOT +from nucypher.crypto import dkg from nucypher.crypto.powers import CryptoPower, TransactingPower from nucypher.network.trackers import OperatorBondedTracker +from nucypher.policy.conditions.lingo import ConditionLingo from nucypher.policy.payment import ContractPayment from nucypher.utilities.emitters import StdoutEmitter from nucypher.utilities.logging import Logger -from tests.mock.ferveo import generate_dkg_transcript, confirm_dkg_transcript class BaseActor: @@ -445,54 +447,145 @@ class Ritualist(BaseActor): contract=self.coordinator_agent.contract ) - def handle_start_ritual(self, ritual_id: int, timestamp: int, nodes: List[ChecksumAddress], *args, **kwargs): - """Check in with the coordinator.""" - # from the tracker's internal cache - ritual = self.ritual_tracker.rituals[ritual_id] - node_index = nodes.index(self.transacting_power.account) - if ritual.performances[node_index].checkin_timestamp != 0: - raise self.RitualError(f"Node {self.transacting_power.account} has already checked in for ritual {ritual_id}") - receipt = self.coordinator_agent.checkin( - ritual_id=ritual_id, - node_index=node_index, - transacting_power=self.transacting_power - ) - return receipt + self.dkg_storage = {"transcripts": {}, "aggregated_transcripts": {}} - def handle_start_transcript_round(self, ritual_id: int, timestamp: int, *args, **kwargs): - """Post a DKG transcript to the blockchain.""" - # from the tracker's internal cache - ritual = self.ritual_tracker.rituals[ritual_id] - node_index = self.ritual_tracker.get_node_index(ritual_id=ritual_id, node=self.transacting_power.account) + def get_ritual(self, ritual_id: int) -> CoordinatorAgent.Ritual: + try: + ritual = self.ritual_tracker.rituals[ritual_id] + except KeyError: + raise self.ActorError(f"{ritual_id} is not in the local cache") + return ritual + + def store_transcript(self, ritual_id: int, transcript: Transcript) -> None: + self.dkg_storage["transcripts"][ritual_id] = bytes(transcript) + + def store_aggregated_transcript( + self, ritual_id: int, aggregated_transcript: AggregatedTranscript + ) -> None: + self.dkg_storage["aggregated_transcripts"][ritual_id] = bytes( + aggregated_transcript + ) + + def get_aggregated_transcript(self, ritual_id: int) -> AggregatedTranscript: + data = self.dkg_storage["aggregated_transcripts"][ritual_id] + aggregated_transcript = AggregatedTranscript.from_bytes(data) + return aggregated_transcript + + def get_transcript(self, ritual_id: int) -> AggregatedTranscript: + data = self.dkg_storage["transcripts"][ritual_id] + transcript = Transcript.from_bytes(data) + return transcript + + def perform_round_1(self, ritual_id: int, timestamp: int, *args, **kwargs): + ritual = self.get_ritual(ritual_id) + if ritual.status != CoordinatorAgent.Ritual.Status.WAITING_FOR_CONFIRMATIONS: + raise self.ActorError( + f"ritual #{ritual.id} is not waiting for transcripts." + ) + node_index = self.ritual_tracker.get_node_index( + ritual_id=ritual_id, node=self.transacting_power.account + ) if ritual.performances[node_index].transcript: - raise self.RitualError(f"Node {self.transacting_power.account} has already posted a transcript for ritual {ritual_id}") - transcript = generate_dkg_transcript() + raise self.RitualError( + f"Node {self.transacting_power.account} has already posted a transcript for ritual {ritual_id}" + ) + self.log.debug( + f"performing round 1 of DKG ritual #{ritual_id} from blocktime {timestamp}" + ) + + try: + transcript = dkg.generate_transcript( + ritual_id=ritual_id, + checksum_address=self.checksum_address, + shares=len(ritual.nodes), + threshold=ritual.threshold, + nodes=ritual.nodes, + ) # TODO: Error handling + except FerveroError: + raise self.ActorError( + f"error generating DKG transcript for ritual #{ritual_id}" + ) + + self.store_transcript(ritual_id=ritual_id, transcript=transcript) + receipt = self.coordinator_agent.post_transcript( - node_index=self.ritual_tracker.get_node_index(ritual_id=ritual_id, node=self.transacting_power.account), + node_index=self.ritual_tracker.get_node_index( + ritual_id=ritual_id, node=self.transacting_power.account + ), ritual_id=ritual_id, transcript=transcript, - transacting_power=self.transacting_power + transacting_power=self.transacting_power, ) + + self.log.debug(f"completed round 1 of DKG ritual #{ritual_id}") return receipt - def handle_start_confirmation_round(self, ritual_id: int, timestamp: int, *args, **kwargs): - """Confirm the DKG transcripts on the blockchain.""" - # from the tracker's internal cache - ritual = self.ritual_tracker.rituals[ritual_id] - transcripts = [(p.node, p.transcript) for p in ritual.performances] - confirmed_indexes = list() - for index, (node, transcript) in enumerate(transcripts): - valid = confirm_dkg_transcript(transcript) - if valid: - confirmed_indexes.append(index) - receipt = self.coordinator_agent.post_confirmations( + def perform_round_2(self, ritual_id: int, timestamp: int, *args, **kwargs): + ritual = self.get_ritual(ritual_id) + if ritual.status != CoordinatorAgent.Ritual.Status.WAITING_FOR_CONFIRMATIONS: + raise self.ActorError( + f"ritual #{ritual.id} is not waiting for transcripts." + ) + self.log.debug( + f"performing round 2 of DKG ritual #{ritual_id} from blocktime {timestamp}" + ) + + try: + aggregated_transcript = dkg.aggregate_transcripts( + ritual_id=ritual_id, + checksum_address=self.checksum_address, + shares=len(ritual.nodes), + threshold=ritual.threshold, + nodes=ritual.nodes, + transcripts=ritual.transcripts, + ) # TODO: Error handling + except FerveoError: + raise self.ActorError( + f"error aggregating DKG transcript fr ritual #{ritual_id}" + ) + + self.store_aggregated_transcript( + ritual_id=ritual_id, aggregated_transcript=aggregated_transcript + ) + + receipt = self.coordinator_agent.post_aggregation( ritual_id=ritual_id, - node_index=self.ritual_tracker.get_node_index(ritual_id=ritual_id, node=self.transacting_power.account), - confirmed_indexes=confirmed_indexes, + node_index=self.ritual_tracker.get_node_index( + ritual_id=ritual_id, node=self.transacting_power.account + ), + aggregated_transcript=aggregated_transcript, transacting_power=self.transacting_power ) + self.log.debug(f"completed round 2 of DKG ritual #{ritual_id}") return receipt + def derive_decryption_share( + self, ritual_id: int, ciphertext: bytes, conditions: ConditionLingo + ) -> DecryptionShare: + ritual = self.get_ritual(ritual_id) + if ritual.status != CoordinatorAgent.Ritual.Status.FINAL: + raise self.ActorError(f"ritual #{ritual.id} is not finalized.") + aggregated_transcript = self.get_aggregated_transcript(ritual_id) + + try: + decryption_share = dkg.derive_decryption_share( + ritual_id=ritual_id, + checksum_address=self.checksum_address, + shares=len(ritual.nodes), + threshold=ritual.threshold, + nodes=ritual.nodes, + aggregated_transcript=aggregated_transcript, + keypair=self.crypto_power(RitualPower), + ciphertext=bytes(ciphertext), + aad=bytes(conditions), + ) + except FerveoError: + raise self.ActorError( + f"error deriving decryption share for ritual #{ritual_id}" + ) + + return decryption_share + class PolicyAuthor(NucypherTokenActor): """Alice base class for blockchain operations, mocking up new policies!""" diff --git a/nucypher/blockchain/eth/agents.py b/nucypher/blockchain/eth/agents.py index 6746d1006..b400d6370 100644 --- a/nucypher/blockchain/eth/agents.py +++ b/nucypher/blockchain/eth/agents.py @@ -545,58 +545,81 @@ class PREApplicationAgent(EthereumContractAgent): class CoordinatorAgent(EthereumContractAgent): - DKG_SIZE = 8 # TODO: get this from the contract - - contract_name: str = 'CoordinatorV1' + contract_name: str = "CoordinatorV3" _proxy_name = None - @dataclass - class RitualStatus: - WAITING_FOR_CHECKINS = 0 - WAITING_FOR_TRANSCRIPTS = 1 - WAITING_FOR_CONFIRMATIONS = 2 - COMPLETED = 3 - FAILED = 4 - - @dataclass - class Performance: - node: ChecksumAddress - confirmed_by: List = field(default_factory=list) - transcript: bytes = bytes() - checkin_timestamp: int = 0 - @dataclass class Ritual: + + @dataclass + class Status: + WAITING_FOR_CHECKINS = 0 + WAITING_FOR_TRANSCRIPTS = 1 + WAITING_FOR_CONFIRMATIONS = 2 + COMPLETED = 3 + FAILED = 4 + FINAL = 5 + + @dataclass + class Performance: + node: ChecksumAddress + aggregated: bool + transcript: bytes + id: int - status: int - init_timestamp: int = 0 - total_checkins: int = 0 - total_transcripts: int = 0 - total_confirmations: int = 0 + status: Status + init_timestamp: int + total_transcripts: int + total_aggregations: int performances: List = field(default_factory=list) + @property + def nodes(self): + return [p.node for p in self.performances] + + @property + def transcripts(self) -> List[bytes]: + transcripts = list() + for p in self.performances: + if p.aggregated: + raise RuntimeError(f"{p.node[:8]} transcript is already aggregated") + transcripts.append(p.transcript) + return transcripts + + @property + def aggregated_transcripts(self) -> List[bytes]: + transcripts = list() + for p in self.performances: + if not p.aggregated: + raise RuntimeError(f"{p.node[:8]} transcript not aggregated") + transcripts.append(p.transcript) + return transcripts + + @property + def shares(self) -> int: + return len(self.nodes) + @contract_api(CONTRACT_CALL) def get_ritual(self, ritual_id: int) -> Ritual: result = self.contract.functions.rituals(int(ritual_id)).call() - ritual = self.Ritual(id=ritual_id, - status=result[0], - init_timestamp=result[1], - total_checkins=result[2], - total_transcripts=result[3], - total_confirmations=result[4], - performances=[]) + ritual = self.Ritual( + id=ritual_id, + status=result[0], + init_timestamp=result[1], + total_transcripts=result[2], + total_aggregations=result[3], + performances=[], + ) return ritual @contract_api(CONTRACT_CALL) - def get_performances(self, ritual_id: int) -> List[Performance]: + def get_performances(self, ritual_id: int) -> List[Ritual.Performance]: result = self.contract.functions.getPerformances(ritual_id).call() performances = list() for r in result: - performance = self.Performance( - node=ChecksumAddress(r[0]), - confirmed_by=r[1], - transcript=bytes(r[2]), - checkin_timestamp=int.from_bytes(r[3], 'big')) + performance = self.Ritual.Performance( + node=ChecksumAddress(r[0]), aggregated=r[1], transcript=bytes(r[2]) + ) performances.append(performance) return performances @@ -607,23 +630,13 @@ class CoordinatorAgent(EthereumContractAgent): @contract_api(TRANSACTION) def initiate_ritual(self, nodes: List[ChecksumAddress], transacting_power: TransactingPower) -> TxReceipt: - """For use by threshold operator accounts only.""" contract_function: ContractFunction = self.contract.functions.initiateRitual(nodes=nodes) receipt = self.blockchain.send_transaction(contract_function=contract_function, transacting_power=transacting_power) return receipt - @contract_api(TRANSACTION) - def checkin(self, ritual_id: int, node_index: int, transacting_power: TransactingPower) -> TxReceipt: - """For use by threshold operator accounts only.""" - contract_function: ContractFunction = self.contract.functions.checkIn(ritual_id, node_index) - receipt = self.blockchain.send_transaction(contract_function=contract_function, - transacting_power=transacting_power) - return receipt - @contract_api(TRANSACTION) def post_transcript(self, ritual_id: int, transcript: bytes, node_index: int, transacting_power: TransactingPower) -> TxReceipt: - """For use by threshold operator accounts only.""" contract_function: ContractFunction = self.contract.functions.postTranscript( ritualId=ritual_id, nodeIndex=node_index, @@ -634,15 +647,33 @@ class CoordinatorAgent(EthereumContractAgent): return receipt @contract_api(TRANSACTION) - def post_confirmations(self, ritual_id: int, node_index: int, confirmed_indexes: List[int], transacting_power: TransactingPower) -> TxReceipt: - """For use by threshold operator accounts only.""" - contract_function: ContractFunction = self.contract.functions.postConfirmation( + def post_aggregation( + self, + ritual_id: int, + node_index: int, + confirmed_indexes: List[int], + transacting_power: TransactingPower, + ) -> TxReceipt: + contract_function: ContractFunction = self.contract.functions.postAggregation( ritualId=ritual_id, nodeIndex=node_index, confirmedNodesIndexes=confirmed_indexes, ) - receipt = self.blockchain.send_transaction(contract_function=contract_function, - transacting_power=transacting_power) + receipt = self.blockchain.send_transaction( + contract_function=contract_function, transacting_power=transacting_power + ) + return receipt + + @contract_api(TRANSACTION) + def finalize_ritual( + self, ritual_id: int, transacting_power: TransactingPower + ) -> TxReceipt: + contract_function: ContractFunction = self.contract.functions.finalize( + ritual_id + ) + receipt = self.blockchain.send_transaction( + contract_function=contract_function, transacting_power=transacting_power + ) return receipt diff --git a/nucypher/blockchain/eth/trackers/dkg.py b/nucypher/blockchain/eth/trackers/dkg.py index ad831b9d6..e51e4a829 100644 --- a/nucypher/blockchain/eth/trackers/dkg.py +++ b/nucypher/blockchain/eth/trackers/dkg.py @@ -1,14 +1,18 @@ import os - import time +from typing import Callable, List, Tuple, Type, Union + from eth_typing import ChecksumAddress -from twisted.internet import reactor, threads -from twisted.internet.defer import Deferred -from twisted.internet.threads import deferToThread -from typing import Callable, List, Optional, Tuple, Union, Type +from ferveo import ( + AggregatedTranscript, + DecryptionShare, + Dkg, + Keypair, + PublicKey, + Transcript, +) +from twisted.internet import threads from web3 import Web3 -# Currently this method is not exposed over official web3 API, -# but we need it to construct eth_getLogs parameters from web3.contract import Contract, ContractEvent from web3.datastructures import AttributeDict from web3.providers import BaseProvider @@ -64,7 +68,7 @@ class RitualTracker: self.log = Logger("RitualTracker") self.ritualist = ritualist - self.rituals = dict() # TODO: use persistent storage + self.rituals = dict() # TODO: use persistent storage? self.eth_provider = eth_provider self.contract = contract @@ -77,16 +81,17 @@ class RitualTracker: # Map events to handlers self.actions = { - contract.events.StartRitual: self.ritualist.handle_start_ritual, - contract.events.StartTranscriptRound: self.ritualist.handle_start_transcript_round, - contract.events.StartConfirmationRound: self.ritualist.handle_start_confirmation_round, + contract.events.StartTranscriptRound: self.ritualist.perform_round_1, + contract.events.StartConfirmationRound: self.ritualist.perform_round_2, } self.events = list(self.actions) self.provider = eth_provider # Remove the default JSON-RPC retry middleware # as it correctly cannot handle eth_getLogs block range throttle down. - self.provider._middlewares = tuple() + self.provider._middlewares = ( + tuple() + ) # TODO: Do this more precisely to not unintentionally remove other middlewares self.web3 = Web3(self.provider) self.scanner = EventActuator( diff --git a/nucypher/crypto/dkg.py b/nucypher/crypto/dkg.py new file mode 100644 index 000000000..67925871d --- /dev/null +++ b/nucypher/crypto/dkg.py @@ -0,0 +1,68 @@ +# Based on original work here: +# https://github.com/nucypher/ferveo/blob/client-server-api/ferveo-python/examples/server_api.py + +from typing import List, Tuple + +from eth_typing import ChecksumAddress +from ferveo import ( + AggregatedTranscript, + DecryptionShare, + Dkg, + Keypair, + PublicKey, + Transcript, +) + + +def __make_dkg( + ritual_id: int, + checksum_address: ChecksumAddress, + shares: int, + threshold: int, + nodes: List[ChecksumAddress], +) -> Dkg: + _dkg = Dkg( + tau=ritual_id, + shares_num=shares, + security_threshold=threshold, + validators=nodes, + me=checksum_address, + ) + return _dkg + + +def generate_dkg_keypair() -> Keypair: + return Keypair.random() + + +def generate_transcript(*args, **kwargs) -> Transcript: + _dkg = __make_dkg(*args, **kwargs) + transcript = _dkg.generate_transcript() + return transcript + + +def aggregate_transcripts( + transcripts: List[bytes], *args, **kwargs +) -> Tuple[AggregatedTranscript, PublicKey]: + _dkg = __make_dkg(*args, **kwargs) + pvss_aggregated = _dkg.aggregate_transcripts(transcripts) + if not pvss_aggregated.validate(_dkg): + raise Exception("validation failed") # TODO: better exception + public_key = _dkg.final_key + return pvss_aggregated, public_key + + +def derive_decryption_share( + aggregated_transcript: AggregatedTranscript, + keypair: Keypair, + ciphertext: bytes, + aad: bytes, + *args, + **kwargs +) -> DecryptionShare: + dkg = __make_dkg(*args, **kwargs) + assert aggregated_transcript.validate(dkg) + decryption_share = aggregated_transcript.create_decryption_share( + dkg, ciphertext, aad, keypair + ) + return decryption_share diff --git a/tests/mock/ferveo.py b/tests/mock/ferveo.py deleted file mode 100644 index d02d4df95..000000000 --- a/tests/mock/ferveo.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Tuple, Dict - -import os - - -class DKGRitual: - domain: bytes - index: int - node: bytes - pvss_params: Dict[str, bytes] - - -def generate_dkg_blinding_keypair(*args, **kwargs) -> Tuple[bytes, bytes]: - """ferveo generate_blinding_keypair""" - return os.urandom(32), os.urandom(32) - - -def generate_dkg_ritual(*args, **kwargs) ->: - """ferveo generate_ritual""" - return DKGRitual(0) - - -def generate_dkg_transcript(dkg, shares: int, *args, **kwargs) -> bytes: - """ferveo generate PVSS""" - return os.urandom(32) - - -def confirm_dkg_transcript(transcript, *args, **kwargs) -> bool: - """ferveo confirm_transcript""" - return True - - -def compute_dfrag(*args, **kwargs): - """ferveo compute_dfrag"""