CoordinatorV3 adaptation and integration with ferveo-server crate draft.

pull/3091/head
Kieran Prasch 2023-02-10 12:11:24 -08:00
parent ac79212b07
commit bc77489e94
5 changed files with 299 additions and 136 deletions

View File

@ -1,9 +1,9 @@
import json import json
import time
from decimal import Decimal from decimal import Decimal
from typing import Optional, Tuple, Union, Dict, List from typing import Optional, Tuple, Union
import maya import maya
import time
from constant_sorrow.constants import FULL from constant_sorrow.constants import FULL
from eth_typing import ChecksumAddress from eth_typing import ChecksumAddress
from hexbytes import HexBytes from hexbytes import HexBytes
@ -15,8 +15,9 @@ from nucypher.blockchain.economics import Economics
from nucypher.blockchain.eth.agents import ( from nucypher.blockchain.eth.agents import (
AdjudicatorAgent, AdjudicatorAgent,
ContractAgency, ContractAgency,
CoordinatorAgent,
NucypherTokenAgent, NucypherTokenAgent,
PREApplicationAgent, CoordinatorAgent, PREApplicationAgent,
) )
from nucypher.blockchain.eth.constants import NULL_ADDRESS from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.blockchain.eth.decorators import save_receipt, validate_checksum_address from nucypher.blockchain.eth.decorators import save_receipt, validate_checksum_address
@ -34,12 +35,13 @@ from nucypher.blockchain.eth.token import NU
from nucypher.blockchain.eth.trackers.dkg import RitualTracker from nucypher.blockchain.eth.trackers.dkg import RitualTracker
from nucypher.blockchain.eth.trackers.pre import WorkTracker from nucypher.blockchain.eth.trackers.pre import WorkTracker
from nucypher.config.constants import DEFAULT_CONFIG_ROOT from nucypher.config.constants import DEFAULT_CONFIG_ROOT
from nucypher.crypto import dkg
from nucypher.crypto.powers import CryptoPower, TransactingPower from nucypher.crypto.powers import CryptoPower, TransactingPower
from nucypher.network.trackers import OperatorBondedTracker from nucypher.network.trackers import OperatorBondedTracker
from nucypher.policy.conditions.lingo import ConditionLingo
from nucypher.policy.payment import ContractPayment from nucypher.policy.payment import ContractPayment
from nucypher.utilities.emitters import StdoutEmitter from nucypher.utilities.emitters import StdoutEmitter
from nucypher.utilities.logging import Logger from nucypher.utilities.logging import Logger
from tests.mock.ferveo import generate_dkg_transcript, confirm_dkg_transcript
class BaseActor: class BaseActor:
@ -445,54 +447,145 @@ class Ritualist(BaseActor):
contract=self.coordinator_agent.contract contract=self.coordinator_agent.contract
) )
def handle_start_ritual(self, ritual_id: int, timestamp: int, nodes: List[ChecksumAddress], *args, **kwargs): self.dkg_storage = {"transcripts": {}, "aggregated_transcripts": {}}
"""Check in with the coordinator."""
# from the tracker's internal cache
ritual = self.ritual_tracker.rituals[ritual_id]
node_index = nodes.index(self.transacting_power.account)
if ritual.performances[node_index].checkin_timestamp != 0:
raise self.RitualError(f"Node {self.transacting_power.account} has already checked in for ritual {ritual_id}")
receipt = self.coordinator_agent.checkin(
ritual_id=ritual_id,
node_index=node_index,
transacting_power=self.transacting_power
)
return receipt
def handle_start_transcript_round(self, ritual_id: int, timestamp: int, *args, **kwargs): def get_ritual(self, ritual_id: int) -> CoordinatorAgent.Ritual:
"""Post a DKG transcript to the blockchain.""" try:
# from the tracker's internal cache ritual = self.ritual_tracker.rituals[ritual_id]
ritual = self.ritual_tracker.rituals[ritual_id] except KeyError:
node_index = self.ritual_tracker.get_node_index(ritual_id=ritual_id, node=self.transacting_power.account) raise self.ActorError(f"{ritual_id} is not in the local cache")
return ritual
def store_transcript(self, ritual_id: int, transcript: Transcript) -> None:
self.dkg_storage["transcripts"][ritual_id] = bytes(transcript)
def store_aggregated_transcript(
self, ritual_id: int, aggregated_transcript: AggregatedTranscript
) -> None:
self.dkg_storage["aggregated_transcripts"][ritual_id] = bytes(
aggregated_transcript
)
def get_aggregated_transcript(self, ritual_id: int) -> AggregatedTranscript:
data = self.dkg_storage["aggregated_transcripts"][ritual_id]
aggregated_transcript = AggregatedTranscript.from_bytes(data)
return aggregated_transcript
def get_transcript(self, ritual_id: int) -> AggregatedTranscript:
data = self.dkg_storage["transcripts"][ritual_id]
transcript = Transcript.from_bytes(data)
return transcript
def perform_round_1(self, ritual_id: int, timestamp: int, *args, **kwargs):
ritual = self.get_ritual(ritual_id)
if ritual.status != CoordinatorAgent.Ritual.Status.WAITING_FOR_CONFIRMATIONS:
raise self.ActorError(
f"ritual #{ritual.id} is not waiting for transcripts."
)
node_index = self.ritual_tracker.get_node_index(
ritual_id=ritual_id, node=self.transacting_power.account
)
if ritual.performances[node_index].transcript: if ritual.performances[node_index].transcript:
raise self.RitualError(f"Node {self.transacting_power.account} has already posted a transcript for ritual {ritual_id}") raise self.RitualError(
transcript = generate_dkg_transcript() f"Node {self.transacting_power.account} has already posted a transcript for ritual {ritual_id}"
)
self.log.debug(
f"performing round 1 of DKG ritual #{ritual_id} from blocktime {timestamp}"
)
try:
transcript = dkg.generate_transcript(
ritual_id=ritual_id,
checksum_address=self.checksum_address,
shares=len(ritual.nodes),
threshold=ritual.threshold,
nodes=ritual.nodes,
) # TODO: Error handling
except FerveroError:
raise self.ActorError(
f"error generating DKG transcript for ritual #{ritual_id}"
)
self.store_transcript(ritual_id=ritual_id, transcript=transcript)
receipt = self.coordinator_agent.post_transcript( receipt = self.coordinator_agent.post_transcript(
node_index=self.ritual_tracker.get_node_index(ritual_id=ritual_id, node=self.transacting_power.account), node_index=self.ritual_tracker.get_node_index(
ritual_id=ritual_id, node=self.transacting_power.account
),
ritual_id=ritual_id, ritual_id=ritual_id,
transcript=transcript, transcript=transcript,
transacting_power=self.transacting_power transacting_power=self.transacting_power,
) )
self.log.debug(f"completed round 1 of DKG ritual #{ritual_id}")
return receipt return receipt
def handle_start_confirmation_round(self, ritual_id: int, timestamp: int, *args, **kwargs): def perform_round_2(self, ritual_id: int, timestamp: int, *args, **kwargs):
"""Confirm the DKG transcripts on the blockchain.""" ritual = self.get_ritual(ritual_id)
# from the tracker's internal cache if ritual.status != CoordinatorAgent.Ritual.Status.WAITING_FOR_CONFIRMATIONS:
ritual = self.ritual_tracker.rituals[ritual_id] raise self.ActorError(
transcripts = [(p.node, p.transcript) for p in ritual.performances] f"ritual #{ritual.id} is not waiting for transcripts."
confirmed_indexes = list() )
for index, (node, transcript) in enumerate(transcripts): self.log.debug(
valid = confirm_dkg_transcript(transcript) f"performing round 2 of DKG ritual #{ritual_id} from blocktime {timestamp}"
if valid: )
confirmed_indexes.append(index)
receipt = self.coordinator_agent.post_confirmations( try:
aggregated_transcript = dkg.aggregate_transcripts(
ritual_id=ritual_id,
checksum_address=self.checksum_address,
shares=len(ritual.nodes),
threshold=ritual.threshold,
nodes=ritual.nodes,
transcripts=ritual.transcripts,
) # TODO: Error handling
except FerveoError:
raise self.ActorError(
f"error aggregating DKG transcript fr ritual #{ritual_id}"
)
self.store_aggregated_transcript(
ritual_id=ritual_id, aggregated_transcript=aggregated_transcript
)
receipt = self.coordinator_agent.post_aggregation(
ritual_id=ritual_id, ritual_id=ritual_id,
node_index=self.ritual_tracker.get_node_index(ritual_id=ritual_id, node=self.transacting_power.account), node_index=self.ritual_tracker.get_node_index(
confirmed_indexes=confirmed_indexes, ritual_id=ritual_id, node=self.transacting_power.account
),
aggregated_transcript=aggregated_transcript,
transacting_power=self.transacting_power transacting_power=self.transacting_power
) )
self.log.debug(f"completed round 2 of DKG ritual #{ritual_id}")
return receipt return receipt
def derive_decryption_share(
self, ritual_id: int, ciphertext: bytes, conditions: ConditionLingo
) -> DecryptionShare:
ritual = self.get_ritual(ritual_id)
if ritual.status != CoordinatorAgent.Ritual.Status.FINAL:
raise self.ActorError(f"ritual #{ritual.id} is not finalized.")
aggregated_transcript = self.get_aggregated_transcript(ritual_id)
try:
decryption_share = dkg.derive_decryption_share(
ritual_id=ritual_id,
checksum_address=self.checksum_address,
shares=len(ritual.nodes),
threshold=ritual.threshold,
nodes=ritual.nodes,
aggregated_transcript=aggregated_transcript,
keypair=self.crypto_power(RitualPower),
ciphertext=bytes(ciphertext),
aad=bytes(conditions),
)
except FerveoError:
raise self.ActorError(
f"error deriving decryption share for ritual #{ritual_id}"
)
return decryption_share
class PolicyAuthor(NucypherTokenActor): class PolicyAuthor(NucypherTokenActor):
"""Alice base class for blockchain operations, mocking up new policies!""" """Alice base class for blockchain operations, mocking up new policies!"""

View File

@ -545,58 +545,81 @@ class PREApplicationAgent(EthereumContractAgent):
class CoordinatorAgent(EthereumContractAgent): class CoordinatorAgent(EthereumContractAgent):
DKG_SIZE = 8 # TODO: get this from the contract contract_name: str = "CoordinatorV3"
contract_name: str = 'CoordinatorV1'
_proxy_name = None _proxy_name = None
@dataclass
class RitualStatus:
WAITING_FOR_CHECKINS = 0
WAITING_FOR_TRANSCRIPTS = 1
WAITING_FOR_CONFIRMATIONS = 2
COMPLETED = 3
FAILED = 4
@dataclass
class Performance:
node: ChecksumAddress
confirmed_by: List = field(default_factory=list)
transcript: bytes = bytes()
checkin_timestamp: int = 0
@dataclass @dataclass
class Ritual: class Ritual:
@dataclass
class Status:
WAITING_FOR_CHECKINS = 0
WAITING_FOR_TRANSCRIPTS = 1
WAITING_FOR_CONFIRMATIONS = 2
COMPLETED = 3
FAILED = 4
FINAL = 5
@dataclass
class Performance:
node: ChecksumAddress
aggregated: bool
transcript: bytes
id: int id: int
status: int status: Status
init_timestamp: int = 0 init_timestamp: int
total_checkins: int = 0 total_transcripts: int
total_transcripts: int = 0 total_aggregations: int
total_confirmations: int = 0
performances: List = field(default_factory=list) performances: List = field(default_factory=list)
@property
def nodes(self):
return [p.node for p in self.performances]
@property
def transcripts(self) -> List[bytes]:
transcripts = list()
for p in self.performances:
if p.aggregated:
raise RuntimeError(f"{p.node[:8]} transcript is already aggregated")
transcripts.append(p.transcript)
return transcripts
@property
def aggregated_transcripts(self) -> List[bytes]:
transcripts = list()
for p in self.performances:
if not p.aggregated:
raise RuntimeError(f"{p.node[:8]} transcript not aggregated")
transcripts.append(p.transcript)
return transcripts
@property
def shares(self) -> int:
return len(self.nodes)
@contract_api(CONTRACT_CALL) @contract_api(CONTRACT_CALL)
def get_ritual(self, ritual_id: int) -> Ritual: def get_ritual(self, ritual_id: int) -> Ritual:
result = self.contract.functions.rituals(int(ritual_id)).call() result = self.contract.functions.rituals(int(ritual_id)).call()
ritual = self.Ritual(id=ritual_id, ritual = self.Ritual(
status=result[0], id=ritual_id,
init_timestamp=result[1], status=result[0],
total_checkins=result[2], init_timestamp=result[1],
total_transcripts=result[3], total_transcripts=result[2],
total_confirmations=result[4], total_aggregations=result[3],
performances=[]) performances=[],
)
return ritual return ritual
@contract_api(CONTRACT_CALL) @contract_api(CONTRACT_CALL)
def get_performances(self, ritual_id: int) -> List[Performance]: def get_performances(self, ritual_id: int) -> List[Ritual.Performance]:
result = self.contract.functions.getPerformances(ritual_id).call() result = self.contract.functions.getPerformances(ritual_id).call()
performances = list() performances = list()
for r in result: for r in result:
performance = self.Performance( performance = self.Ritual.Performance(
node=ChecksumAddress(r[0]), node=ChecksumAddress(r[0]), aggregated=r[1], transcript=bytes(r[2])
confirmed_by=r[1], )
transcript=bytes(r[2]),
checkin_timestamp=int.from_bytes(r[3], 'big'))
performances.append(performance) performances.append(performance)
return performances return performances
@ -607,23 +630,13 @@ class CoordinatorAgent(EthereumContractAgent):
@contract_api(TRANSACTION) @contract_api(TRANSACTION)
def initiate_ritual(self, nodes: List[ChecksumAddress], transacting_power: TransactingPower) -> TxReceipt: def initiate_ritual(self, nodes: List[ChecksumAddress], transacting_power: TransactingPower) -> TxReceipt:
"""For use by threshold operator accounts only."""
contract_function: ContractFunction = self.contract.functions.initiateRitual(nodes=nodes) contract_function: ContractFunction = self.contract.functions.initiateRitual(nodes=nodes)
receipt = self.blockchain.send_transaction(contract_function=contract_function, receipt = self.blockchain.send_transaction(contract_function=contract_function,
transacting_power=transacting_power) transacting_power=transacting_power)
return receipt return receipt
@contract_api(TRANSACTION)
def checkin(self, ritual_id: int, node_index: int, transacting_power: TransactingPower) -> TxReceipt:
"""For use by threshold operator accounts only."""
contract_function: ContractFunction = self.contract.functions.checkIn(ritual_id, node_index)
receipt = self.blockchain.send_transaction(contract_function=contract_function,
transacting_power=transacting_power)
return receipt
@contract_api(TRANSACTION) @contract_api(TRANSACTION)
def post_transcript(self, ritual_id: int, transcript: bytes, node_index: int, transacting_power: TransactingPower) -> TxReceipt: def post_transcript(self, ritual_id: int, transcript: bytes, node_index: int, transacting_power: TransactingPower) -> TxReceipt:
"""For use by threshold operator accounts only."""
contract_function: ContractFunction = self.contract.functions.postTranscript( contract_function: ContractFunction = self.contract.functions.postTranscript(
ritualId=ritual_id, ritualId=ritual_id,
nodeIndex=node_index, nodeIndex=node_index,
@ -634,15 +647,33 @@ class CoordinatorAgent(EthereumContractAgent):
return receipt return receipt
@contract_api(TRANSACTION) @contract_api(TRANSACTION)
def post_confirmations(self, ritual_id: int, node_index: int, confirmed_indexes: List[int], transacting_power: TransactingPower) -> TxReceipt: def post_aggregation(
"""For use by threshold operator accounts only.""" self,
contract_function: ContractFunction = self.contract.functions.postConfirmation( ritual_id: int,
node_index: int,
confirmed_indexes: List[int],
transacting_power: TransactingPower,
) -> TxReceipt:
contract_function: ContractFunction = self.contract.functions.postAggregation(
ritualId=ritual_id, ritualId=ritual_id,
nodeIndex=node_index, nodeIndex=node_index,
confirmedNodesIndexes=confirmed_indexes, confirmedNodesIndexes=confirmed_indexes,
) )
receipt = self.blockchain.send_transaction(contract_function=contract_function, receipt = self.blockchain.send_transaction(
transacting_power=transacting_power) contract_function=contract_function, transacting_power=transacting_power
)
return receipt
@contract_api(TRANSACTION)
def finalize_ritual(
self, ritual_id: int, transacting_power: TransactingPower
) -> TxReceipt:
contract_function: ContractFunction = self.contract.functions.finalize(
ritual_id
)
receipt = self.blockchain.send_transaction(
contract_function=contract_function, transacting_power=transacting_power
)
return receipt return receipt

View File

@ -1,14 +1,18 @@
import os import os
import time import time
from typing import Callable, List, Tuple, Type, Union
from eth_typing import ChecksumAddress from eth_typing import ChecksumAddress
from twisted.internet import reactor, threads from ferveo import (
from twisted.internet.defer import Deferred AggregatedTranscript,
from twisted.internet.threads import deferToThread DecryptionShare,
from typing import Callable, List, Optional, Tuple, Union, Type Dkg,
Keypair,
PublicKey,
Transcript,
)
from twisted.internet import threads
from web3 import Web3 from web3 import Web3
# Currently this method is not exposed over official web3 API,
# but we need it to construct eth_getLogs parameters
from web3.contract import Contract, ContractEvent from web3.contract import Contract, ContractEvent
from web3.datastructures import AttributeDict from web3.datastructures import AttributeDict
from web3.providers import BaseProvider from web3.providers import BaseProvider
@ -64,7 +68,7 @@ class RitualTracker:
self.log = Logger("RitualTracker") self.log = Logger("RitualTracker")
self.ritualist = ritualist self.ritualist = ritualist
self.rituals = dict() # TODO: use persistent storage self.rituals = dict() # TODO: use persistent storage?
self.eth_provider = eth_provider self.eth_provider = eth_provider
self.contract = contract self.contract = contract
@ -77,16 +81,17 @@ class RitualTracker:
# Map events to handlers # Map events to handlers
self.actions = { self.actions = {
contract.events.StartRitual: self.ritualist.handle_start_ritual, contract.events.StartTranscriptRound: self.ritualist.perform_round_1,
contract.events.StartTranscriptRound: self.ritualist.handle_start_transcript_round, contract.events.StartConfirmationRound: self.ritualist.perform_round_2,
contract.events.StartConfirmationRound: self.ritualist.handle_start_confirmation_round,
} }
self.events = list(self.actions) self.events = list(self.actions)
self.provider = eth_provider self.provider = eth_provider
# Remove the default JSON-RPC retry middleware # Remove the default JSON-RPC retry middleware
# as it correctly cannot handle eth_getLogs block range throttle down. # as it correctly cannot handle eth_getLogs block range throttle down.
self.provider._middlewares = tuple() self.provider._middlewares = (
tuple()
) # TODO: Do this more precisely to not unintentionally remove other middlewares
self.web3 = Web3(self.provider) self.web3 = Web3(self.provider)
self.scanner = EventActuator( self.scanner = EventActuator(

68
nucypher/crypto/dkg.py Normal file
View File

@ -0,0 +1,68 @@
# Based on original work here:
# https://github.com/nucypher/ferveo/blob/client-server-api/ferveo-python/examples/server_api.py
from typing import List, Tuple
from eth_typing import ChecksumAddress
from ferveo import (
AggregatedTranscript,
DecryptionShare,
Dkg,
Keypair,
PublicKey,
Transcript,
)
def __make_dkg(
ritual_id: int,
checksum_address: ChecksumAddress,
shares: int,
threshold: int,
nodes: List[ChecksumAddress],
) -> Dkg:
_dkg = Dkg(
tau=ritual_id,
shares_num=shares,
security_threshold=threshold,
validators=nodes,
me=checksum_address,
)
return _dkg
def generate_dkg_keypair() -> Keypair:
return Keypair.random()
def generate_transcript(*args, **kwargs) -> Transcript:
_dkg = __make_dkg(*args, **kwargs)
transcript = _dkg.generate_transcript()
return transcript
def aggregate_transcripts(
transcripts: List[bytes], *args, **kwargs
) -> Tuple[AggregatedTranscript, PublicKey]:
_dkg = __make_dkg(*args, **kwargs)
pvss_aggregated = _dkg.aggregate_transcripts(transcripts)
if not pvss_aggregated.validate(_dkg):
raise Exception("validation failed") # TODO: better exception
public_key = _dkg.final_key
return pvss_aggregated, public_key
def derive_decryption_share(
aggregated_transcript: AggregatedTranscript,
keypair: Keypair,
ciphertext: bytes,
aad: bytes,
*args,
**kwargs
) -> DecryptionShare:
dkg = __make_dkg(*args, **kwargs)
assert aggregated_transcript.validate(dkg)
decryption_share = aggregated_transcript.create_decryption_share(
dkg, ciphertext, aad, keypair
)
return decryption_share

View File

@ -1,34 +0,0 @@
from typing import Tuple, Dict
import os
class DKGRitual:
domain: bytes
index: int
node: bytes
pvss_params: Dict[str, bytes]
def generate_dkg_blinding_keypair(*args, **kwargs) -> Tuple[bytes, bytes]:
"""ferveo generate_blinding_keypair"""
return os.urandom(32), os.urandom(32)
def generate_dkg_ritual(*args, **kwargs) ->:
"""ferveo generate_ritual"""
return DKGRitual(0)
def generate_dkg_transcript(dkg, shares: int, *args, **kwargs) -> bytes:
"""ferveo generate PVSS"""
return os.urandom(32)
def confirm_dkg_transcript(transcript, *args, **kwargs) -> bool:
"""ferveo confirm_transcript"""
return True
def compute_dfrag(*args, **kwargs):
"""ferveo compute_dfrag"""