CoordinatorV3 adaptation and integration with ferveo-server crate draft.

pull/3091/head
Kieran Prasch 2023-02-10 12:11:24 -08:00
parent ac79212b07
commit bc77489e94
5 changed files with 299 additions and 136 deletions

View File

@ -1,9 +1,9 @@
import json
import time
from decimal import Decimal
from typing import Optional, Tuple, Union, Dict, List
from typing import Optional, Tuple, Union
import maya
import time
from constant_sorrow.constants import FULL
from eth_typing import ChecksumAddress
from hexbytes import HexBytes
@ -15,8 +15,9 @@ from nucypher.blockchain.economics import Economics
from nucypher.blockchain.eth.agents import (
AdjudicatorAgent,
ContractAgency,
CoordinatorAgent,
NucypherTokenAgent,
PREApplicationAgent, CoordinatorAgent,
PREApplicationAgent,
)
from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.blockchain.eth.decorators import save_receipt, validate_checksum_address
@ -34,12 +35,13 @@ from nucypher.blockchain.eth.token import NU
from nucypher.blockchain.eth.trackers.dkg import RitualTracker
from nucypher.blockchain.eth.trackers.pre import WorkTracker
from nucypher.config.constants import DEFAULT_CONFIG_ROOT
from nucypher.crypto import dkg
from nucypher.crypto.powers import CryptoPower, TransactingPower
from nucypher.network.trackers import OperatorBondedTracker
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.payment import ContractPayment
from nucypher.utilities.emitters import StdoutEmitter
from nucypher.utilities.logging import Logger
from tests.mock.ferveo import generate_dkg_transcript, confirm_dkg_transcript
class BaseActor:
@ -445,54 +447,145 @@ class Ritualist(BaseActor):
contract=self.coordinator_agent.contract
)
def handle_start_ritual(self, ritual_id: int, timestamp: int, nodes: List[ChecksumAddress], *args, **kwargs):
"""Check in with the coordinator."""
# from the tracker's internal cache
ritual = self.ritual_tracker.rituals[ritual_id]
node_index = nodes.index(self.transacting_power.account)
if ritual.performances[node_index].checkin_timestamp != 0:
raise self.RitualError(f"Node {self.transacting_power.account} has already checked in for ritual {ritual_id}")
receipt = self.coordinator_agent.checkin(
ritual_id=ritual_id,
node_index=node_index,
transacting_power=self.transacting_power
)
return receipt
self.dkg_storage = {"transcripts": {}, "aggregated_transcripts": {}}
def handle_start_transcript_round(self, ritual_id: int, timestamp: int, *args, **kwargs):
"""Post a DKG transcript to the blockchain."""
# from the tracker's internal cache
ritual = self.ritual_tracker.rituals[ritual_id]
node_index = self.ritual_tracker.get_node_index(ritual_id=ritual_id, node=self.transacting_power.account)
def get_ritual(self, ritual_id: int) -> CoordinatorAgent.Ritual:
try:
ritual = self.ritual_tracker.rituals[ritual_id]
except KeyError:
raise self.ActorError(f"{ritual_id} is not in the local cache")
return ritual
def store_transcript(self, ritual_id: int, transcript: Transcript) -> None:
self.dkg_storage["transcripts"][ritual_id] = bytes(transcript)
def store_aggregated_transcript(
self, ritual_id: int, aggregated_transcript: AggregatedTranscript
) -> None:
self.dkg_storage["aggregated_transcripts"][ritual_id] = bytes(
aggregated_transcript
)
def get_aggregated_transcript(self, ritual_id: int) -> AggregatedTranscript:
data = self.dkg_storage["aggregated_transcripts"][ritual_id]
aggregated_transcript = AggregatedTranscript.from_bytes(data)
return aggregated_transcript
def get_transcript(self, ritual_id: int) -> AggregatedTranscript:
data = self.dkg_storage["transcripts"][ritual_id]
transcript = Transcript.from_bytes(data)
return transcript
def perform_round_1(self, ritual_id: int, timestamp: int, *args, **kwargs):
ritual = self.get_ritual(ritual_id)
if ritual.status != CoordinatorAgent.Ritual.Status.WAITING_FOR_CONFIRMATIONS:
raise self.ActorError(
f"ritual #{ritual.id} is not waiting for transcripts."
)
node_index = self.ritual_tracker.get_node_index(
ritual_id=ritual_id, node=self.transacting_power.account
)
if ritual.performances[node_index].transcript:
raise self.RitualError(f"Node {self.transacting_power.account} has already posted a transcript for ritual {ritual_id}")
transcript = generate_dkg_transcript()
raise self.RitualError(
f"Node {self.transacting_power.account} has already posted a transcript for ritual {ritual_id}"
)
self.log.debug(
f"performing round 1 of DKG ritual #{ritual_id} from blocktime {timestamp}"
)
try:
transcript = dkg.generate_transcript(
ritual_id=ritual_id,
checksum_address=self.checksum_address,
shares=len(ritual.nodes),
threshold=ritual.threshold,
nodes=ritual.nodes,
) # TODO: Error handling
except FerveroError:
raise self.ActorError(
f"error generating DKG transcript for ritual #{ritual_id}"
)
self.store_transcript(ritual_id=ritual_id, transcript=transcript)
receipt = self.coordinator_agent.post_transcript(
node_index=self.ritual_tracker.get_node_index(ritual_id=ritual_id, node=self.transacting_power.account),
node_index=self.ritual_tracker.get_node_index(
ritual_id=ritual_id, node=self.transacting_power.account
),
ritual_id=ritual_id,
transcript=transcript,
transacting_power=self.transacting_power
transacting_power=self.transacting_power,
)
self.log.debug(f"completed round 1 of DKG ritual #{ritual_id}")
return receipt
def handle_start_confirmation_round(self, ritual_id: int, timestamp: int, *args, **kwargs):
"""Confirm the DKG transcripts on the blockchain."""
# from the tracker's internal cache
ritual = self.ritual_tracker.rituals[ritual_id]
transcripts = [(p.node, p.transcript) for p in ritual.performances]
confirmed_indexes = list()
for index, (node, transcript) in enumerate(transcripts):
valid = confirm_dkg_transcript(transcript)
if valid:
confirmed_indexes.append(index)
receipt = self.coordinator_agent.post_confirmations(
def perform_round_2(self, ritual_id: int, timestamp: int, *args, **kwargs):
ritual = self.get_ritual(ritual_id)
if ritual.status != CoordinatorAgent.Ritual.Status.WAITING_FOR_CONFIRMATIONS:
raise self.ActorError(
f"ritual #{ritual.id} is not waiting for transcripts."
)
self.log.debug(
f"performing round 2 of DKG ritual #{ritual_id} from blocktime {timestamp}"
)
try:
aggregated_transcript = dkg.aggregate_transcripts(
ritual_id=ritual_id,
checksum_address=self.checksum_address,
shares=len(ritual.nodes),
threshold=ritual.threshold,
nodes=ritual.nodes,
transcripts=ritual.transcripts,
) # TODO: Error handling
except FerveoError:
raise self.ActorError(
f"error aggregating DKG transcript fr ritual #{ritual_id}"
)
self.store_aggregated_transcript(
ritual_id=ritual_id, aggregated_transcript=aggregated_transcript
)
receipt = self.coordinator_agent.post_aggregation(
ritual_id=ritual_id,
node_index=self.ritual_tracker.get_node_index(ritual_id=ritual_id, node=self.transacting_power.account),
confirmed_indexes=confirmed_indexes,
node_index=self.ritual_tracker.get_node_index(
ritual_id=ritual_id, node=self.transacting_power.account
),
aggregated_transcript=aggregated_transcript,
transacting_power=self.transacting_power
)
self.log.debug(f"completed round 2 of DKG ritual #{ritual_id}")
return receipt
def derive_decryption_share(
self, ritual_id: int, ciphertext: bytes, conditions: ConditionLingo
) -> DecryptionShare:
ritual = self.get_ritual(ritual_id)
if ritual.status != CoordinatorAgent.Ritual.Status.FINAL:
raise self.ActorError(f"ritual #{ritual.id} is not finalized.")
aggregated_transcript = self.get_aggregated_transcript(ritual_id)
try:
decryption_share = dkg.derive_decryption_share(
ritual_id=ritual_id,
checksum_address=self.checksum_address,
shares=len(ritual.nodes),
threshold=ritual.threshold,
nodes=ritual.nodes,
aggregated_transcript=aggregated_transcript,
keypair=self.crypto_power(RitualPower),
ciphertext=bytes(ciphertext),
aad=bytes(conditions),
)
except FerveoError:
raise self.ActorError(
f"error deriving decryption share for ritual #{ritual_id}"
)
return decryption_share
class PolicyAuthor(NucypherTokenActor):
"""Alice base class for blockchain operations, mocking up new policies!"""

View File

@ -545,58 +545,81 @@ class PREApplicationAgent(EthereumContractAgent):
class CoordinatorAgent(EthereumContractAgent):
DKG_SIZE = 8 # TODO: get this from the contract
contract_name: str = 'CoordinatorV1'
contract_name: str = "CoordinatorV3"
_proxy_name = None
@dataclass
class RitualStatus:
WAITING_FOR_CHECKINS = 0
WAITING_FOR_TRANSCRIPTS = 1
WAITING_FOR_CONFIRMATIONS = 2
COMPLETED = 3
FAILED = 4
@dataclass
class Performance:
node: ChecksumAddress
confirmed_by: List = field(default_factory=list)
transcript: bytes = bytes()
checkin_timestamp: int = 0
@dataclass
class Ritual:
@dataclass
class Status:
WAITING_FOR_CHECKINS = 0
WAITING_FOR_TRANSCRIPTS = 1
WAITING_FOR_CONFIRMATIONS = 2
COMPLETED = 3
FAILED = 4
FINAL = 5
@dataclass
class Performance:
node: ChecksumAddress
aggregated: bool
transcript: bytes
id: int
status: int
init_timestamp: int = 0
total_checkins: int = 0
total_transcripts: int = 0
total_confirmations: int = 0
status: Status
init_timestamp: int
total_transcripts: int
total_aggregations: int
performances: List = field(default_factory=list)
@property
def nodes(self):
return [p.node for p in self.performances]
@property
def transcripts(self) -> List[bytes]:
transcripts = list()
for p in self.performances:
if p.aggregated:
raise RuntimeError(f"{p.node[:8]} transcript is already aggregated")
transcripts.append(p.transcript)
return transcripts
@property
def aggregated_transcripts(self) -> List[bytes]:
transcripts = list()
for p in self.performances:
if not p.aggregated:
raise RuntimeError(f"{p.node[:8]} transcript not aggregated")
transcripts.append(p.transcript)
return transcripts
@property
def shares(self) -> int:
return len(self.nodes)
@contract_api(CONTRACT_CALL)
def get_ritual(self, ritual_id: int) -> Ritual:
result = self.contract.functions.rituals(int(ritual_id)).call()
ritual = self.Ritual(id=ritual_id,
status=result[0],
init_timestamp=result[1],
total_checkins=result[2],
total_transcripts=result[3],
total_confirmations=result[4],
performances=[])
ritual = self.Ritual(
id=ritual_id,
status=result[0],
init_timestamp=result[1],
total_transcripts=result[2],
total_aggregations=result[3],
performances=[],
)
return ritual
@contract_api(CONTRACT_CALL)
def get_performances(self, ritual_id: int) -> List[Performance]:
def get_performances(self, ritual_id: int) -> List[Ritual.Performance]:
result = self.contract.functions.getPerformances(ritual_id).call()
performances = list()
for r in result:
performance = self.Performance(
node=ChecksumAddress(r[0]),
confirmed_by=r[1],
transcript=bytes(r[2]),
checkin_timestamp=int.from_bytes(r[3], 'big'))
performance = self.Ritual.Performance(
node=ChecksumAddress(r[0]), aggregated=r[1], transcript=bytes(r[2])
)
performances.append(performance)
return performances
@ -607,23 +630,13 @@ class CoordinatorAgent(EthereumContractAgent):
@contract_api(TRANSACTION)
def initiate_ritual(self, nodes: List[ChecksumAddress], transacting_power: TransactingPower) -> TxReceipt:
"""For use by threshold operator accounts only."""
contract_function: ContractFunction = self.contract.functions.initiateRitual(nodes=nodes)
receipt = self.blockchain.send_transaction(contract_function=contract_function,
transacting_power=transacting_power)
return receipt
@contract_api(TRANSACTION)
def checkin(self, ritual_id: int, node_index: int, transacting_power: TransactingPower) -> TxReceipt:
"""For use by threshold operator accounts only."""
contract_function: ContractFunction = self.contract.functions.checkIn(ritual_id, node_index)
receipt = self.blockchain.send_transaction(contract_function=contract_function,
transacting_power=transacting_power)
return receipt
@contract_api(TRANSACTION)
def post_transcript(self, ritual_id: int, transcript: bytes, node_index: int, transacting_power: TransactingPower) -> TxReceipt:
"""For use by threshold operator accounts only."""
contract_function: ContractFunction = self.contract.functions.postTranscript(
ritualId=ritual_id,
nodeIndex=node_index,
@ -634,15 +647,33 @@ class CoordinatorAgent(EthereumContractAgent):
return receipt
@contract_api(TRANSACTION)
def post_confirmations(self, ritual_id: int, node_index: int, confirmed_indexes: List[int], transacting_power: TransactingPower) -> TxReceipt:
"""For use by threshold operator accounts only."""
contract_function: ContractFunction = self.contract.functions.postConfirmation(
def post_aggregation(
self,
ritual_id: int,
node_index: int,
confirmed_indexes: List[int],
transacting_power: TransactingPower,
) -> TxReceipt:
contract_function: ContractFunction = self.contract.functions.postAggregation(
ritualId=ritual_id,
nodeIndex=node_index,
confirmedNodesIndexes=confirmed_indexes,
)
receipt = self.blockchain.send_transaction(contract_function=contract_function,
transacting_power=transacting_power)
receipt = self.blockchain.send_transaction(
contract_function=contract_function, transacting_power=transacting_power
)
return receipt
@contract_api(TRANSACTION)
def finalize_ritual(
self, ritual_id: int, transacting_power: TransactingPower
) -> TxReceipt:
contract_function: ContractFunction = self.contract.functions.finalize(
ritual_id
)
receipt = self.blockchain.send_transaction(
contract_function=contract_function, transacting_power=transacting_power
)
return receipt

View File

@ -1,14 +1,18 @@
import os
import time
from typing import Callable, List, Tuple, Type, Union
from eth_typing import ChecksumAddress
from twisted.internet import reactor, threads
from twisted.internet.defer import Deferred
from twisted.internet.threads import deferToThread
from typing import Callable, List, Optional, Tuple, Union, Type
from ferveo import (
AggregatedTranscript,
DecryptionShare,
Dkg,
Keypair,
PublicKey,
Transcript,
)
from twisted.internet import threads
from web3 import Web3
# Currently this method is not exposed over official web3 API,
# but we need it to construct eth_getLogs parameters
from web3.contract import Contract, ContractEvent
from web3.datastructures import AttributeDict
from web3.providers import BaseProvider
@ -64,7 +68,7 @@ class RitualTracker:
self.log = Logger("RitualTracker")
self.ritualist = ritualist
self.rituals = dict() # TODO: use persistent storage
self.rituals = dict() # TODO: use persistent storage?
self.eth_provider = eth_provider
self.contract = contract
@ -77,16 +81,17 @@ class RitualTracker:
# Map events to handlers
self.actions = {
contract.events.StartRitual: self.ritualist.handle_start_ritual,
contract.events.StartTranscriptRound: self.ritualist.handle_start_transcript_round,
contract.events.StartConfirmationRound: self.ritualist.handle_start_confirmation_round,
contract.events.StartTranscriptRound: self.ritualist.perform_round_1,
contract.events.StartConfirmationRound: self.ritualist.perform_round_2,
}
self.events = list(self.actions)
self.provider = eth_provider
# Remove the default JSON-RPC retry middleware
# as it correctly cannot handle eth_getLogs block range throttle down.
self.provider._middlewares = tuple()
self.provider._middlewares = (
tuple()
) # TODO: Do this more precisely to not unintentionally remove other middlewares
self.web3 = Web3(self.provider)
self.scanner = EventActuator(

68
nucypher/crypto/dkg.py Normal file
View File

@ -0,0 +1,68 @@
# Based on original work here:
# https://github.com/nucypher/ferveo/blob/client-server-api/ferveo-python/examples/server_api.py
from typing import List, Tuple
from eth_typing import ChecksumAddress
from ferveo import (
AggregatedTranscript,
DecryptionShare,
Dkg,
Keypair,
PublicKey,
Transcript,
)
def __make_dkg(
ritual_id: int,
checksum_address: ChecksumAddress,
shares: int,
threshold: int,
nodes: List[ChecksumAddress],
) -> Dkg:
_dkg = Dkg(
tau=ritual_id,
shares_num=shares,
security_threshold=threshold,
validators=nodes,
me=checksum_address,
)
return _dkg
def generate_dkg_keypair() -> Keypair:
return Keypair.random()
def generate_transcript(*args, **kwargs) -> Transcript:
_dkg = __make_dkg(*args, **kwargs)
transcript = _dkg.generate_transcript()
return transcript
def aggregate_transcripts(
transcripts: List[bytes], *args, **kwargs
) -> Tuple[AggregatedTranscript, PublicKey]:
_dkg = __make_dkg(*args, **kwargs)
pvss_aggregated = _dkg.aggregate_transcripts(transcripts)
if not pvss_aggregated.validate(_dkg):
raise Exception("validation failed") # TODO: better exception
public_key = _dkg.final_key
return pvss_aggregated, public_key
def derive_decryption_share(
aggregated_transcript: AggregatedTranscript,
keypair: Keypair,
ciphertext: bytes,
aad: bytes,
*args,
**kwargs
) -> DecryptionShare:
dkg = __make_dkg(*args, **kwargs)
assert aggregated_transcript.validate(dkg)
decryption_share = aggregated_transcript.create_decryption_share(
dkg, ciphertext, aad, keypair
)
return decryption_share

View File

@ -1,34 +0,0 @@
from typing import Tuple, Dict
import os
class DKGRitual:
domain: bytes
index: int
node: bytes
pvss_params: Dict[str, bytes]
def generate_dkg_blinding_keypair(*args, **kwargs) -> Tuple[bytes, bytes]:
"""ferveo generate_blinding_keypair"""
return os.urandom(32), os.urandom(32)
def generate_dkg_ritual(*args, **kwargs) ->:
"""ferveo generate_ritual"""
return DKGRitual(0)
def generate_dkg_transcript(dkg, shares: int, *args, **kwargs) -> bytes:
"""ferveo generate PVSS"""
return os.urandom(32)
def confirm_dkg_transcript(transcript, *args, **kwargs) -> bool:
"""ferveo confirm_transcript"""
return True
def compute_dfrag(*args, **kwargs):
"""ferveo compute_dfrag"""