updates mocks to sync with coordinator agent updates

remotes/origin/v7.4.x
KPrasch 2024-01-26 15:08:22 +01:00 committed by Derek Pierre
parent 9f8661c0bb
commit c1495b66a7
2 changed files with 21 additions and 26 deletions

View File

@ -1,6 +1,6 @@
import time
from enum import Enum
from typing import Dict, List, NamedTuple
from typing import Dict, List, NamedTuple, Optional
from eth_typing import ChecksumAddress
from eth_utils import keccak
@ -121,7 +121,7 @@ class MockCoordinatorAgent(MockContractAgent):
self._get_staking_provider_from_operator(operator=operator_address)
or transacting_power.account
)
participant = self.get_participant_from_provider(ritual_id, provider)
participant = self.get_participant(ritual_id, provider)
participant.transcript = bytes(transcript)
ritual.total_transcripts += 1
if ritual.total_transcripts == ritual.dkg_size:
@ -151,7 +151,7 @@ class MockCoordinatorAgent(MockContractAgent):
self._get_staking_provider_from_operator(operator=operator_address)
or transacting_power.account
)
participant = self.get_participant_from_provider(ritual_id, provider)
participant = self.get_participant(ritual_id, provider)
participant.aggregated = True
participant.decryption_request_static_key = bytes(participant_public_key)
@ -207,23 +207,26 @@ class MockCoordinatorAgent(MockContractAgent):
def number_of_rituals(self) -> int:
return len(self.rituals)
def get_ritual(
self, ritual_id: int, with_participants: bool = False
) -> CoordinatorAgent.Ritual:
def get_ritual(self, ritual_id: int) -> CoordinatorAgent.Ritual:
return self.rituals[ritual_id]
def get_participants(self, ritual_id: int) -> List[Participant]:
return self.rituals[ritual_id].participants
def is_participant(self, ritual_id: int, provider: ChecksumAddress) -> bool:
try:
self.get_participant(ritual_id, provider)
return True
except ValueError:
return False
def get_participant_from_provider(
self, ritual_id: int, provider: ChecksumAddress
) -> Participant:
def get_participant(self, ritual_id: int, provider: ChecksumAddress) -> Participant:
for p in self.rituals[ritual_id].participants:
if p.provider == provider:
return p
raise ValueError(f"Provider {provider} not found for ritual #{ritual_id}")
def get_providers(self, ritual_id: int) -> List[ChecksumAddress]:
return [p.provider for p in self.rituals[ritual_id].participants]
def get_ritual_status(self, ritual_id: int) -> int:
ritual = self.rituals[ritual_id]
timestamp = int(ritual.init_timestamp)
@ -259,7 +262,7 @@ class MockCoordinatorAgent(MockContractAgent):
f"No ritual id found for public key 0x{bytes(public_key).hex()}"
)
def get_ritual_public_key(self, ritual_id: int) -> DkgPublicKey:
def get_ritual_public_key(self, ritual_id: int) -> Optional[DkgPublicKey]:
status = self.get_ritual_status(ritual_id=ritual_id)
if status != self.Ritual.Status.ACTIVE and status != self.Ritual.Status.EXPIRED:
# TODO should we raise here instead?
@ -277,8 +280,8 @@ class MockCoordinatorAgent(MockContractAgent):
participant_keys = self._participant_keys_history[provider]
for participant_key in reversed(participant_keys):
if participant_key.lastRitualId <= ritual_id:
g2Point = participant_key.publicKey
return g2Point.to_public_key()
g2_point = participant_key.publicKey
return g2_point.to_public_key()
raise ValueError(
f"Public key not found for provider {provider} for ritual #{ritual_id}"

View File

@ -113,9 +113,7 @@ def test_perform_round_1(
)
agent.get_ritual = lambda *args, **kwargs: ritual
agent.get_participants = lambda *args, **kwargs: participants
agent.get_participant_from_provider = lambda ritual_id, provider: participants[
provider
]
agent.get_participant = lambda ritual_id, provider: participants[provider]
# ensure no operation performed for non-application-state
non_application_states = [
@ -156,9 +154,7 @@ def test_perform_round_1(
ursula.dkg_storage.store_transcript_receipt(ritual_id=0, txhash_or_receipt=None)
# participant already posted transcript
participant = agent.get_participant_from_provider(
ritual_id=0, provider=ursula.checksum_address
)
participant = agent.get_participant(ritual_id=0, provider=ursula.checksum_address)
participant.transcript = bytes(random_transcript)
# try submitting again
@ -213,9 +209,7 @@ def test_perform_round_2(
agent.get_ritual = lambda *args, **kwargs: ritual
agent.get_participants = lambda *args, **kwargs: participants
agent.get_participant_from_provider = lambda ritual_id, provider: participants[
provider
]
agent.get_participant = lambda ritual_id, provider: participants[provider]
# ensure no operation performed for non-application-state
non_application_states = [
@ -253,9 +247,7 @@ def test_perform_round_2(
)
# participant already posted aggregated transcript
participant = agent.get_participant_from_provider(
ritual_id=0, provider=ursula.checksum_address
)
participant = agent.get_participant(ritual_id=0, provider=ursula.checksum_address)
participant.aggregated = True
# try submitting again