mirror of https://github.com/nucypher/nucypher.git
handle shared coordinator reads and dkg artifact preperation logic between ursula and decryptors.
parent
8cd51eebf5
commit
96cac14fd5
|
@ -291,24 +291,6 @@ class Operator(BaseActor):
|
|||
|
||||
return providers
|
||||
|
||||
def get_ritual(
|
||||
self, ritual_id: int, download_transcripts: bool
|
||||
) -> CoordinatorAgent.Ritual:
|
||||
"""assembles Coordinator data using multiple RPC requests."""
|
||||
ritual = self.coordinator_agent.get_ritual(ritual_id) # call 1
|
||||
addresses = self.coordinator_agent.get_providers(ritual_id) # call 2
|
||||
participants = []
|
||||
for index, address in enumerate(addresses): # n calls
|
||||
if download_transcripts:
|
||||
participant = self.coordinator_agent.get_participant(ritual_id, address)
|
||||
else:
|
||||
participant = CoordinatorAgent.Ritual.Participant(
|
||||
index=index, provider=address
|
||||
)
|
||||
participants.append(participant)
|
||||
ritual.participants = participants
|
||||
return ritual
|
||||
|
||||
def _resolve_validators(
|
||||
self,
|
||||
ritual: CoordinatorAgent.Ritual,
|
||||
|
@ -398,7 +380,7 @@ class Operator(BaseActor):
|
|||
|
||||
# validate the active ritual tracker state
|
||||
participant = self.coordinator_agent.get_participant(
|
||||
ritual_id=ritual_id, provider=self.checksum_address
|
||||
ritual_id=ritual_id, provider=self.checksum_address, transcript=False
|
||||
)
|
||||
if participant.transcript:
|
||||
self.log.debug(
|
||||
|
@ -425,9 +407,9 @@ class Operator(BaseActor):
|
|||
f"performing round 1 of DKG ritual #{ritual_id} from blocktime {timestamp} with authority {authority}."
|
||||
)
|
||||
|
||||
# gather the ritual metadata
|
||||
ritual = self.get_ritual(ritual_id, download_transcripts=False)
|
||||
validators = list(zip(*self._resolve_validators(ritual, ritual_id)))
|
||||
# gather the ritual metadata and DKG artifacts
|
||||
ritual = self.coordinator_agent.get_ritual(ritual_id)
|
||||
validators = self._resolve_validators(ritual, ritual_id)
|
||||
|
||||
# generate a transcript
|
||||
try:
|
||||
|
@ -473,7 +455,7 @@ class Operator(BaseActor):
|
|||
|
||||
# validate the active ritual tracker state
|
||||
participant = self.coordinator_agent.get_participant(
|
||||
ritual_id=ritual_id, provider=self.checksum_address
|
||||
ritual_id=ritual_id, provider=self.checksum_address, transcript=False
|
||||
)
|
||||
if participant.aggregated:
|
||||
self.log.debug(
|
||||
|
@ -496,13 +478,18 @@ class Operator(BaseActor):
|
|||
f"{self.transacting_power.account[:8]} performing round 2 of DKG ritual #{ritual_id} from blocktime {timestamp}"
|
||||
)
|
||||
|
||||
ritual = self.coordinator_agent.get_ritual(ritual_id, with_participants=True)
|
||||
nodes, transcripts = self._resolve_validators(ritual, ritual_id)
|
||||
if not all(transcripts):
|
||||
ritual = self.coordinator_agent.get_ritual(ritual_id, transcripts=True)
|
||||
missing = sum(1 for t in ritual.transcripts if not t)
|
||||
if missing:
|
||||
raise self.ActorError(
|
||||
f"ritual #{ritual_id} is missing transcripts from {len([t for t in transcripts if not t])} nodes."
|
||||
f"ritual #{ritual_id} is missing transcripts from {missing} nodes."
|
||||
)
|
||||
|
||||
# Prepare the DKG artifacts
|
||||
validators = self._resolve_validators(ritual, ritual_id)
|
||||
transcripts = (Transcript.from_bytes(bytes(t)) for t in ritual.transcripts)
|
||||
messages = list(zip(validators, transcripts))
|
||||
|
||||
# Aggregate the transcripts
|
||||
try:
|
||||
result = self.ritual_power.aggregate_transcripts(
|
||||
|
@ -510,7 +497,7 @@ class Operator(BaseActor):
|
|||
shares=ritual.shares,
|
||||
checksum_address=self.checksum_address,
|
||||
ritual_id=ritual_id,
|
||||
transcripts=transcripts
|
||||
transcripts=messages,
|
||||
)
|
||||
except Exception as e:
|
||||
self.log.debug(f"Failed to aggregate transcripts for ritual #{ritual_id}: {str(e)}")
|
||||
|
@ -558,16 +545,10 @@ class Operator(BaseActor):
|
|||
ritual = self.coordinator_agent.get_ritual(ritual_id)
|
||||
if not self.coordinator_agent.is_ritual_active(ritual_id=ritual_id):
|
||||
raise self.ActorError(f"Ritual #{ritual_id} is not active.")
|
||||
|
||||
nodes, transcripts = list(zip(*self._resolve_validators(ritual, ritual_id)))
|
||||
if not all(transcripts):
|
||||
raise self.ActorError(f"Ritual #{ritual_id} is missing transcripts")
|
||||
|
||||
# TODO: consider the usage of local DKG artifact storage here #3052
|
||||
# aggregated_transcript_bytes = self.dkg_storage.get_aggregated_transcript(ritual_id)
|
||||
validators = list(self._resolve_validators(ritual, ritual_id))
|
||||
aggregated_transcript = AggregatedTranscript.from_bytes(bytes(ritual.aggregated_transcript))
|
||||
decryption_share = self.ritual_power.derive_decryption_share(
|
||||
nodes=nodes,
|
||||
nodes=validators,
|
||||
threshold=ritual.threshold,
|
||||
shares=ritual.shares,
|
||||
checksum_address=self.checksum_address,
|
||||
|
@ -577,7 +558,6 @@ class Operator(BaseActor):
|
|||
aad=aad,
|
||||
variant=variant
|
||||
)
|
||||
|
||||
return decryption_share
|
||||
|
||||
def decrypt_threshold_decryption_request(
|
||||
|
@ -602,7 +582,7 @@ class Operator(BaseActor):
|
|||
# dkg_public_key = this_node.dkg_storage.get_public_key(decryption_request.ritual_id)
|
||||
|
||||
# enforces that the node is part of the ritual
|
||||
participating = self.coordinator_agent.is_provider_participating(
|
||||
participating = self.coordinator_agent.is_participant(
|
||||
ritual_id=decryption_request.ritual_id, provider=self.checksum_address
|
||||
)
|
||||
if not participating:
|
||||
|
|
|
@ -723,11 +723,8 @@ class CoordinatorAgent(EthereumContractAgent):
|
|||
return [p.provider for p in self.participants]
|
||||
|
||||
@property
|
||||
def transcripts(self) -> List[Tuple[ChecksumAddress, bytes]]:
|
||||
transcripts = list()
|
||||
for p in self.participants:
|
||||
transcripts.append((p.provider, p.transcript))
|
||||
return transcripts
|
||||
def transcripts(self) -> Iterable[bytes]:
|
||||
return [p.transcript for p in self.participants]
|
||||
|
||||
@property
|
||||
def shares(self) -> int:
|
||||
|
@ -748,7 +745,7 @@ class CoordinatorAgent(EthereumContractAgent):
|
|||
return self.contract.functions.timeout().call()
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_ritual(self, ritual_id: int) -> Ritual:
|
||||
def rituals(self, ritual_id: int) -> Ritual:
|
||||
result = self.contract.functions.rituals(int(ritual_id)).call()
|
||||
ritual = self.Ritual(
|
||||
initiator=ChecksumAddress(result[0]),
|
||||
|
@ -764,11 +761,32 @@ class CoordinatorAgent(EthereumContractAgent):
|
|||
aggregated_transcript=bytes(result[11]),
|
||||
participants=[], # solidity does not return sub-structs
|
||||
)
|
||||
|
||||
# public key
|
||||
ritual.public_key = self.Ritual.G1Point(result[10][0], result[10][1])
|
||||
return ritual
|
||||
|
||||
def get_ritual(
|
||||
self,
|
||||
ritual_id: int,
|
||||
transcripts: bool = False,
|
||||
participants: bool = True,
|
||||
) -> Ritual:
|
||||
"""assembles Coordinator data using multiple RPC requests."""
|
||||
if not participants and transcripts:
|
||||
raise ValueError("Cannot get transcripts without participants")
|
||||
|
||||
ritual = self.rituals(ritual_id) # call 1
|
||||
if not participants:
|
||||
return ritual
|
||||
|
||||
participants = []
|
||||
addresses = self.get_providers(ritual_id) # call 2
|
||||
for index, address in enumerate(addresses): # n calls
|
||||
participant = self.get_participant(ritual_id, address, transcripts)
|
||||
participants.append(participant)
|
||||
ritual.participants = participants
|
||||
return ritual
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_ritual_status(self, ritual_id: int) -> int:
|
||||
result = self.contract.functions.getRitualState(ritual_id).call()
|
||||
|
@ -791,9 +809,11 @@ class CoordinatorAgent(EthereumContractAgent):
|
|||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_participant(
|
||||
self, ritual_id: int, provider: ChecksumAddress
|
||||
self, ritual_id: int, provider: ChecksumAddress, transcript: bool
|
||||
) -> Ritual.Participant:
|
||||
data, index = self.contract.functions.getParticipant(ritual_id, provider).call()
|
||||
data, index = self.contract.functions.getParticipant(
|
||||
ritual_id, provider, transcript
|
||||
).call()
|
||||
participant = self.Ritual.Participant(
|
||||
index=index,
|
||||
provider=ChecksumAddress(data[0]),
|
||||
|
@ -805,7 +825,7 @@ class CoordinatorAgent(EthereumContractAgent):
|
|||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_provider_public_key(
|
||||
self, provider: ChecksumAddress, ritual_id: int
|
||||
self, provider: ChecksumAddress, ritual_id: int
|
||||
) -> FerveoPublicKey:
|
||||
result = self.contract.functions.getProviderPublicKey(
|
||||
provider, ritual_id
|
||||
|
|
|
@ -246,7 +246,7 @@ class ActiveRitualTracker:
|
|||
Returns node's participant information for the provided
|
||||
ritual id; None if node is not participating in the ritual
|
||||
"""
|
||||
participants = self.coordinator_agent.get_participants(ritual_id=ritual_id)
|
||||
participants = self.coordinator_agent.get_ritual(ritual_id=ritual_id)
|
||||
for p in participants:
|
||||
if p.provider == self.operator.checksum_address:
|
||||
return p
|
||||
|
|
|
@ -694,8 +694,9 @@ class Bob(Character):
|
|||
)
|
||||
return ritual_id
|
||||
|
||||
def get_ritual_from_id(self, ritual_id) -> CoordinatorAgent.Ritual:
|
||||
ritual = self._get_coordinator_agent().get_ritual(ritual_id)
|
||||
def get_ritual(self, ritual_id) -> CoordinatorAgent.Ritual:
|
||||
agent = self._get_coordinator_agent()
|
||||
ritual = agent.get_ritual(ritual_id)
|
||||
return ritual
|
||||
|
||||
def threshold_decrypt(
|
||||
|
@ -708,7 +709,7 @@ class Bob(Character):
|
|||
ritual_id = self.get_ritual_id_from_public_key(
|
||||
public_key=threshold_message_kit.acp.public_key
|
||||
)
|
||||
ritual = self.get_ritual_from_id(ritual_id=ritual_id)
|
||||
ritual = self.get_ritual(ritual_id=ritual_id)
|
||||
|
||||
if ursulas:
|
||||
for ursula in ursulas:
|
||||
|
@ -726,10 +727,9 @@ class Bob(Character):
|
|||
variant=variant,
|
||||
context=context,
|
||||
)
|
||||
participant_public_keys = ritual.participant_public_keys
|
||||
decryption_shares = self._get_decryption_shares(
|
||||
decryption_request=decryption_request,
|
||||
participant_public_keys=participant_public_keys,
|
||||
participant_public_keys=ritual.participant_public_keys,
|
||||
threshold=ritual.threshold,
|
||||
timeout=decryption_timeout,
|
||||
)
|
||||
|
|
|
@ -166,7 +166,9 @@ def test_ursula_ritualist(
|
|||
assert (
|
||||
len(
|
||||
coordinator_agent.get_participant(
|
||||
ritual_id=RITUAL_ID, provider=ursula.checksum_address
|
||||
ritual_id=RITUAL_ID,
|
||||
provider=ursula.checksum_address,
|
||||
transcript=True,
|
||||
).transcript
|
||||
)
|
||||
> 0
|
||||
|
|
|
@ -99,8 +99,8 @@ def test_initiate_ritual(
|
|||
ritual = agent.get_ritual(ritual_id)
|
||||
assert ritual.authority == authority
|
||||
|
||||
participants = agent.get_participants(ritual_id)
|
||||
assert [p.provider for p in participants] == cohort
|
||||
ritual = agent.get_ritual(ritual_id)
|
||||
assert [p.provider for p in ritual.participants] == cohort
|
||||
|
||||
assert (
|
||||
agent.get_ritual_status(ritual_id=ritual_id)
|
||||
|
@ -128,8 +128,8 @@ def test_post_transcript(agent, transcripts, transacting_powers):
|
|||
assert event["args"]["ritualId"] == ritual_id
|
||||
assert event["args"]["transcriptDigest"] == keccak(transcripts[i])
|
||||
|
||||
participants = agent.get_participants(ritual_id)
|
||||
assert [p.transcript for p in participants] == transcripts
|
||||
ritual = agent.get_ritual(ritual_id, transcripts=True)
|
||||
assert [p.transcript for p in ritual.participants] == transcripts
|
||||
|
||||
assert (
|
||||
agent.get_ritual_status(ritual_id=ritual_id)
|
||||
|
@ -167,7 +167,7 @@ def test_post_aggregation(
|
|||
bytes(aggregated_transcript)
|
||||
)
|
||||
|
||||
participants = agent.get_participants(ritual_id)
|
||||
participants = agent.get_ritual(ritual_id).participants
|
||||
for p in participants:
|
||||
assert p.aggregated
|
||||
assert p.decryption_request_static_key == bytes(
|
||||
|
|
|
@ -243,7 +243,6 @@ def test_first_scan_start_block_calc_is_not_perfect_go_back_more_blocks(ritualis
|
|||
|
||||
|
||||
def test_get_ritual_participant_info(ritualist, get_random_checksum_address):
|
||||
mocked_agent = ritualist.coordinator_agent
|
||||
active_ritual_tracker = ActiveRitualTracker(operator=ritualist)
|
||||
|
||||
participants = []
|
||||
|
@ -253,7 +252,6 @@ def test_get_ritual_participant_info(ritualist, get_random_checksum_address):
|
|||
index=i, provider=get_random_checksum_address()
|
||||
)
|
||||
participants.append(participant)
|
||||
mocked_agent.get_participants.return_value = participants
|
||||
|
||||
# operator not in participants list
|
||||
participant_info = active_ritual_tracker._get_ritual_participant_info(ritual_id=0)
|
||||
|
@ -274,7 +272,6 @@ def test_get_ritual_participant_info(ritualist, get_random_checksum_address):
|
|||
def test_get_participation_state_values_from_contract(
|
||||
ritualist, get_random_checksum_address
|
||||
):
|
||||
mocked_agent = ritualist.coordinator_agent
|
||||
active_ritual_tracker = ActiveRitualTracker(operator=ritualist)
|
||||
|
||||
participants = []
|
||||
|
@ -285,8 +282,6 @@ def test_get_participation_state_values_from_contract(
|
|||
)
|
||||
participants.append(participant)
|
||||
|
||||
mocked_agent.get_participants.return_value = participants
|
||||
|
||||
# not participating so everything should be False
|
||||
(
|
||||
participating,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import time
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Dict, List, NamedTuple, Optional
|
||||
|
||||
|
@ -121,7 +122,7 @@ class MockCoordinatorAgent(MockContractAgent):
|
|||
self._get_staking_provider_from_operator(operator=operator_address)
|
||||
or transacting_power.account
|
||||
)
|
||||
participant = self.get_participant(ritual_id, provider)
|
||||
participant = self.get_participant(ritual_id, provider, False)
|
||||
participant.transcript = bytes(transcript)
|
||||
ritual.total_transcripts += 1
|
||||
if ritual.total_transcripts == ritual.dkg_size:
|
||||
|
@ -151,7 +152,7 @@ class MockCoordinatorAgent(MockContractAgent):
|
|||
self._get_staking_provider_from_operator(operator=operator_address)
|
||||
or transacting_power.account
|
||||
)
|
||||
participant = self.get_participant(ritual_id, provider)
|
||||
participant = self.get_participant(ritual_id, provider, True)
|
||||
participant.aggregated = True
|
||||
participant.decryption_request_static_key = bytes(participant_public_key)
|
||||
|
||||
|
@ -207,19 +208,27 @@ class MockCoordinatorAgent(MockContractAgent):
|
|||
def number_of_rituals(self) -> int:
|
||||
return len(self.rituals)
|
||||
|
||||
def get_ritual(self, ritual_id: int) -> CoordinatorAgent.Ritual:
|
||||
return self.rituals[ritual_id]
|
||||
def get_ritual(
|
||||
self, ritual_id: int, transcripts: bool = False, participants: bool = True
|
||||
) -> CoordinatorAgent.Ritual:
|
||||
ritual = deepcopy(self.rituals[ritual_id])
|
||||
return ritual
|
||||
|
||||
def is_participant(self, ritual_id: int, provider: ChecksumAddress) -> bool:
|
||||
try:
|
||||
self.get_participant(ritual_id, provider)
|
||||
self.get_participant(ritual_id, provider, False)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def get_participant(self, ritual_id: int, provider: ChecksumAddress) -> Participant:
|
||||
def get_participant(
|
||||
self, ritual_id: int, provider: ChecksumAddress, transcript: bool
|
||||
) -> Participant:
|
||||
for p in self.rituals[ritual_id].participants:
|
||||
if p.provider == provider:
|
||||
# if not transcripts:
|
||||
# p = deepcopy(p)
|
||||
# p.transcript = b""
|
||||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider} not found for ritual #{ritual_id}")
|
||||
|
|
|
@ -75,7 +75,6 @@ def test_initiate_ritual(
|
|||
participants=participants,
|
||||
)
|
||||
agent.get_ritual = lambda *args, **kwargs: ritual
|
||||
agent.get_participants = lambda *args, **kwargs: participants
|
||||
|
||||
assert receipt["transactionHash"]
|
||||
number_of_rituals = agent.number_of_rituals()
|
||||
|
@ -112,8 +111,9 @@ def test_perform_round_1(
|
|||
participants=list(participants.values()),
|
||||
)
|
||||
agent.get_ritual = lambda *args, **kwargs: ritual
|
||||
agent.get_participants = lambda *args, **kwargs: participants
|
||||
agent.get_participant = lambda ritual_id, provider: participants[provider]
|
||||
agent.get_participant = lambda ritual_id, provider, transcripts: participants[
|
||||
provider
|
||||
]
|
||||
|
||||
# ensure no operation performed for non-application-state
|
||||
non_application_states = [
|
||||
|
@ -154,7 +154,9 @@ def test_perform_round_1(
|
|||
ursula.dkg_storage.store_transcript_receipt(ritual_id=0, txhash_or_receipt=None)
|
||||
|
||||
# participant already posted transcript
|
||||
participant = agent.get_participant(ritual_id=0, provider=ursula.checksum_address)
|
||||
participant = agent.get_participant(
|
||||
ritual_id=0, provider=ursula.checksum_address, transcript=False
|
||||
)
|
||||
participant.transcript = bytes(random_transcript)
|
||||
|
||||
# try submitting again
|
||||
|
@ -208,7 +210,6 @@ def test_perform_round_2(
|
|||
)
|
||||
|
||||
agent.get_ritual = lambda *args, **kwargs: ritual
|
||||
agent.get_participants = lambda *args, **kwargs: participants
|
||||
agent.get_participant = lambda ritual_id, provider: participants[provider]
|
||||
|
||||
# ensure no operation performed for non-application-state
|
||||
|
|
Loading…
Reference in New Issue