handle shared coordinator reads and dkg artifact preperation logic between ursula and decryptors.

remotes/origin/v7.4.x
KPrasch 2024-01-26 21:26:56 +01:00 committed by Derek Pierre
parent 8cd51eebf5
commit 96cac14fd5
9 changed files with 83 additions and 76 deletions

View File

@ -291,24 +291,6 @@ class Operator(BaseActor):
return providers
def get_ritual(
self, ritual_id: int, download_transcripts: bool
) -> CoordinatorAgent.Ritual:
"""assembles Coordinator data using multiple RPC requests."""
ritual = self.coordinator_agent.get_ritual(ritual_id) # call 1
addresses = self.coordinator_agent.get_providers(ritual_id) # call 2
participants = []
for index, address in enumerate(addresses): # n calls
if download_transcripts:
participant = self.coordinator_agent.get_participant(ritual_id, address)
else:
participant = CoordinatorAgent.Ritual.Participant(
index=index, provider=address
)
participants.append(participant)
ritual.participants = participants
return ritual
def _resolve_validators(
self,
ritual: CoordinatorAgent.Ritual,
@ -398,7 +380,7 @@ class Operator(BaseActor):
# validate the active ritual tracker state
participant = self.coordinator_agent.get_participant(
ritual_id=ritual_id, provider=self.checksum_address
ritual_id=ritual_id, provider=self.checksum_address, transcript=False
)
if participant.transcript:
self.log.debug(
@ -425,9 +407,9 @@ class Operator(BaseActor):
f"performing round 1 of DKG ritual #{ritual_id} from blocktime {timestamp} with authority {authority}."
)
# gather the ritual metadata
ritual = self.get_ritual(ritual_id, download_transcripts=False)
validators = list(zip(*self._resolve_validators(ritual, ritual_id)))
# gather the ritual metadata and DKG artifacts
ritual = self.coordinator_agent.get_ritual(ritual_id)
validators = self._resolve_validators(ritual, ritual_id)
# generate a transcript
try:
@ -473,7 +455,7 @@ class Operator(BaseActor):
# validate the active ritual tracker state
participant = self.coordinator_agent.get_participant(
ritual_id=ritual_id, provider=self.checksum_address
ritual_id=ritual_id, provider=self.checksum_address, transcript=False
)
if participant.aggregated:
self.log.debug(
@ -496,13 +478,18 @@ class Operator(BaseActor):
f"{self.transacting_power.account[:8]} performing round 2 of DKG ritual #{ritual_id} from blocktime {timestamp}"
)
ritual = self.coordinator_agent.get_ritual(ritual_id, with_participants=True)
nodes, transcripts = self._resolve_validators(ritual, ritual_id)
if not all(transcripts):
ritual = self.coordinator_agent.get_ritual(ritual_id, transcripts=True)
missing = sum(1 for t in ritual.transcripts if not t)
if missing:
raise self.ActorError(
f"ritual #{ritual_id} is missing transcripts from {len([t for t in transcripts if not t])} nodes."
f"ritual #{ritual_id} is missing transcripts from {missing} nodes."
)
# Prepare the DKG artifacts
validators = self._resolve_validators(ritual, ritual_id)
transcripts = (Transcript.from_bytes(bytes(t)) for t in ritual.transcripts)
messages = list(zip(validators, transcripts))
# Aggregate the transcripts
try:
result = self.ritual_power.aggregate_transcripts(
@ -510,7 +497,7 @@ class Operator(BaseActor):
shares=ritual.shares,
checksum_address=self.checksum_address,
ritual_id=ritual_id,
transcripts=transcripts
transcripts=messages,
)
except Exception as e:
self.log.debug(f"Failed to aggregate transcripts for ritual #{ritual_id}: {str(e)}")
@ -558,16 +545,10 @@ class Operator(BaseActor):
ritual = self.coordinator_agent.get_ritual(ritual_id)
if not self.coordinator_agent.is_ritual_active(ritual_id=ritual_id):
raise self.ActorError(f"Ritual #{ritual_id} is not active.")
nodes, transcripts = list(zip(*self._resolve_validators(ritual, ritual_id)))
if not all(transcripts):
raise self.ActorError(f"Ritual #{ritual_id} is missing transcripts")
# TODO: consider the usage of local DKG artifact storage here #3052
# aggregated_transcript_bytes = self.dkg_storage.get_aggregated_transcript(ritual_id)
validators = list(self._resolve_validators(ritual, ritual_id))
aggregated_transcript = AggregatedTranscript.from_bytes(bytes(ritual.aggregated_transcript))
decryption_share = self.ritual_power.derive_decryption_share(
nodes=nodes,
nodes=validators,
threshold=ritual.threshold,
shares=ritual.shares,
checksum_address=self.checksum_address,
@ -577,7 +558,6 @@ class Operator(BaseActor):
aad=aad,
variant=variant
)
return decryption_share
def decrypt_threshold_decryption_request(
@ -602,7 +582,7 @@ class Operator(BaseActor):
# dkg_public_key = this_node.dkg_storage.get_public_key(decryption_request.ritual_id)
# enforces that the node is part of the ritual
participating = self.coordinator_agent.is_provider_participating(
participating = self.coordinator_agent.is_participant(
ritual_id=decryption_request.ritual_id, provider=self.checksum_address
)
if not participating:

View File

@ -723,11 +723,8 @@ class CoordinatorAgent(EthereumContractAgent):
return [p.provider for p in self.participants]
@property
def transcripts(self) -> List[Tuple[ChecksumAddress, bytes]]:
transcripts = list()
for p in self.participants:
transcripts.append((p.provider, p.transcript))
return transcripts
def transcripts(self) -> Iterable[bytes]:
return [p.transcript for p in self.participants]
@property
def shares(self) -> int:
@ -748,7 +745,7 @@ class CoordinatorAgent(EthereumContractAgent):
return self.contract.functions.timeout().call()
@contract_api(CONTRACT_CALL)
def get_ritual(self, ritual_id: int) -> Ritual:
def rituals(self, ritual_id: int) -> Ritual:
result = self.contract.functions.rituals(int(ritual_id)).call()
ritual = self.Ritual(
initiator=ChecksumAddress(result[0]),
@ -764,11 +761,32 @@ class CoordinatorAgent(EthereumContractAgent):
aggregated_transcript=bytes(result[11]),
participants=[], # solidity does not return sub-structs
)
# public key
ritual.public_key = self.Ritual.G1Point(result[10][0], result[10][1])
return ritual
def get_ritual(
self,
ritual_id: int,
transcripts: bool = False,
participants: bool = True,
) -> Ritual:
"""assembles Coordinator data using multiple RPC requests."""
if not participants and transcripts:
raise ValueError("Cannot get transcripts without participants")
ritual = self.rituals(ritual_id) # call 1
if not participants:
return ritual
participants = []
addresses = self.get_providers(ritual_id) # call 2
for index, address in enumerate(addresses): # n calls
participant = self.get_participant(ritual_id, address, transcripts)
participants.append(participant)
ritual.participants = participants
return ritual
@contract_api(CONTRACT_CALL)
def get_ritual_status(self, ritual_id: int) -> int:
result = self.contract.functions.getRitualState(ritual_id).call()
@ -791,9 +809,11 @@ class CoordinatorAgent(EthereumContractAgent):
@contract_api(CONTRACT_CALL)
def get_participant(
self, ritual_id: int, provider: ChecksumAddress
self, ritual_id: int, provider: ChecksumAddress, transcript: bool
) -> Ritual.Participant:
data, index = self.contract.functions.getParticipant(ritual_id, provider).call()
data, index = self.contract.functions.getParticipant(
ritual_id, provider, transcript
).call()
participant = self.Ritual.Participant(
index=index,
provider=ChecksumAddress(data[0]),
@ -805,7 +825,7 @@ class CoordinatorAgent(EthereumContractAgent):
@contract_api(CONTRACT_CALL)
def get_provider_public_key(
self, provider: ChecksumAddress, ritual_id: int
self, provider: ChecksumAddress, ritual_id: int
) -> FerveoPublicKey:
result = self.contract.functions.getProviderPublicKey(
provider, ritual_id

View File

@ -246,7 +246,7 @@ class ActiveRitualTracker:
Returns node's participant information for the provided
ritual id; None if node is not participating in the ritual
"""
participants = self.coordinator_agent.get_participants(ritual_id=ritual_id)
participants = self.coordinator_agent.get_ritual(ritual_id=ritual_id)
for p in participants:
if p.provider == self.operator.checksum_address:
return p

View File

@ -694,8 +694,9 @@ class Bob(Character):
)
return ritual_id
def get_ritual_from_id(self, ritual_id) -> CoordinatorAgent.Ritual:
ritual = self._get_coordinator_agent().get_ritual(ritual_id)
def get_ritual(self, ritual_id) -> CoordinatorAgent.Ritual:
agent = self._get_coordinator_agent()
ritual = agent.get_ritual(ritual_id)
return ritual
def threshold_decrypt(
@ -708,7 +709,7 @@ class Bob(Character):
ritual_id = self.get_ritual_id_from_public_key(
public_key=threshold_message_kit.acp.public_key
)
ritual = self.get_ritual_from_id(ritual_id=ritual_id)
ritual = self.get_ritual(ritual_id=ritual_id)
if ursulas:
for ursula in ursulas:
@ -726,10 +727,9 @@ class Bob(Character):
variant=variant,
context=context,
)
participant_public_keys = ritual.participant_public_keys
decryption_shares = self._get_decryption_shares(
decryption_request=decryption_request,
participant_public_keys=participant_public_keys,
participant_public_keys=ritual.participant_public_keys,
threshold=ritual.threshold,
timeout=decryption_timeout,
)

View File

@ -166,7 +166,9 @@ def test_ursula_ritualist(
assert (
len(
coordinator_agent.get_participant(
ritual_id=RITUAL_ID, provider=ursula.checksum_address
ritual_id=RITUAL_ID,
provider=ursula.checksum_address,
transcript=True,
).transcript
)
> 0

View File

@ -99,8 +99,8 @@ def test_initiate_ritual(
ritual = agent.get_ritual(ritual_id)
assert ritual.authority == authority
participants = agent.get_participants(ritual_id)
assert [p.provider for p in participants] == cohort
ritual = agent.get_ritual(ritual_id)
assert [p.provider for p in ritual.participants] == cohort
assert (
agent.get_ritual_status(ritual_id=ritual_id)
@ -128,8 +128,8 @@ def test_post_transcript(agent, transcripts, transacting_powers):
assert event["args"]["ritualId"] == ritual_id
assert event["args"]["transcriptDigest"] == keccak(transcripts[i])
participants = agent.get_participants(ritual_id)
assert [p.transcript for p in participants] == transcripts
ritual = agent.get_ritual(ritual_id, transcripts=True)
assert [p.transcript for p in ritual.participants] == transcripts
assert (
agent.get_ritual_status(ritual_id=ritual_id)
@ -167,7 +167,7 @@ def test_post_aggregation(
bytes(aggregated_transcript)
)
participants = agent.get_participants(ritual_id)
participants = agent.get_ritual(ritual_id).participants
for p in participants:
assert p.aggregated
assert p.decryption_request_static_key == bytes(

View File

@ -243,7 +243,6 @@ def test_first_scan_start_block_calc_is_not_perfect_go_back_more_blocks(ritualis
def test_get_ritual_participant_info(ritualist, get_random_checksum_address):
mocked_agent = ritualist.coordinator_agent
active_ritual_tracker = ActiveRitualTracker(operator=ritualist)
participants = []
@ -253,7 +252,6 @@ def test_get_ritual_participant_info(ritualist, get_random_checksum_address):
index=i, provider=get_random_checksum_address()
)
participants.append(participant)
mocked_agent.get_participants.return_value = participants
# operator not in participants list
participant_info = active_ritual_tracker._get_ritual_participant_info(ritual_id=0)
@ -274,7 +272,6 @@ def test_get_ritual_participant_info(ritualist, get_random_checksum_address):
def test_get_participation_state_values_from_contract(
ritualist, get_random_checksum_address
):
mocked_agent = ritualist.coordinator_agent
active_ritual_tracker = ActiveRitualTracker(operator=ritualist)
participants = []
@ -285,8 +282,6 @@ def test_get_participation_state_values_from_contract(
)
participants.append(participant)
mocked_agent.get_participants.return_value = participants
# not participating so everything should be False
(
participating,

View File

@ -1,4 +1,5 @@
import time
from copy import deepcopy
from enum import Enum
from typing import Dict, List, NamedTuple, Optional
@ -121,7 +122,7 @@ class MockCoordinatorAgent(MockContractAgent):
self._get_staking_provider_from_operator(operator=operator_address)
or transacting_power.account
)
participant = self.get_participant(ritual_id, provider)
participant = self.get_participant(ritual_id, provider, False)
participant.transcript = bytes(transcript)
ritual.total_transcripts += 1
if ritual.total_transcripts == ritual.dkg_size:
@ -151,7 +152,7 @@ class MockCoordinatorAgent(MockContractAgent):
self._get_staking_provider_from_operator(operator=operator_address)
or transacting_power.account
)
participant = self.get_participant(ritual_id, provider)
participant = self.get_participant(ritual_id, provider, True)
participant.aggregated = True
participant.decryption_request_static_key = bytes(participant_public_key)
@ -207,19 +208,27 @@ class MockCoordinatorAgent(MockContractAgent):
def number_of_rituals(self) -> int:
return len(self.rituals)
def get_ritual(self, ritual_id: int) -> CoordinatorAgent.Ritual:
return self.rituals[ritual_id]
def get_ritual(
self, ritual_id: int, transcripts: bool = False, participants: bool = True
) -> CoordinatorAgent.Ritual:
ritual = deepcopy(self.rituals[ritual_id])
return ritual
def is_participant(self, ritual_id: int, provider: ChecksumAddress) -> bool:
try:
self.get_participant(ritual_id, provider)
self.get_participant(ritual_id, provider, False)
return True
except ValueError:
return False
def get_participant(self, ritual_id: int, provider: ChecksumAddress) -> Participant:
def get_participant(
self, ritual_id: int, provider: ChecksumAddress, transcript: bool
) -> Participant:
for p in self.rituals[ritual_id].participants:
if p.provider == provider:
# if not transcripts:
# p = deepcopy(p)
# p.transcript = b""
return p
raise ValueError(f"Provider {provider} not found for ritual #{ritual_id}")

View File

@ -75,7 +75,6 @@ def test_initiate_ritual(
participants=participants,
)
agent.get_ritual = lambda *args, **kwargs: ritual
agent.get_participants = lambda *args, **kwargs: participants
assert receipt["transactionHash"]
number_of_rituals = agent.number_of_rituals()
@ -112,8 +111,9 @@ def test_perform_round_1(
participants=list(participants.values()),
)
agent.get_ritual = lambda *args, **kwargs: ritual
agent.get_participants = lambda *args, **kwargs: participants
agent.get_participant = lambda ritual_id, provider: participants[provider]
agent.get_participant = lambda ritual_id, provider, transcripts: participants[
provider
]
# ensure no operation performed for non-application-state
non_application_states = [
@ -154,7 +154,9 @@ def test_perform_round_1(
ursula.dkg_storage.store_transcript_receipt(ritual_id=0, txhash_or_receipt=None)
# participant already posted transcript
participant = agent.get_participant(ritual_id=0, provider=ursula.checksum_address)
participant = agent.get_participant(
ritual_id=0, provider=ursula.checksum_address, transcript=False
)
participant.transcript = bytes(random_transcript)
# try submitting again
@ -208,7 +210,6 @@ def test_perform_round_2(
)
agent.get_ritual = lambda *args, **kwargs: ritual
agent.get_participants = lambda *args, **kwargs: participants
agent.get_participant = lambda ritual_id, provider: participants[provider]
# ensure no operation performed for non-application-state