diff --git a/tests/mock/coordinator.py b/tests/mock/coordinator.py index c6422205c..674b25149 100644 --- a/tests/mock/coordinator.py +++ b/tests/mock/coordinator.py @@ -1,6 +1,6 @@ import time from enum import Enum -from typing import Dict, List, NamedTuple +from typing import Dict, List, NamedTuple, Optional from eth_typing import ChecksumAddress from eth_utils import keccak @@ -121,7 +121,7 @@ class MockCoordinatorAgent(MockContractAgent): self._get_staking_provider_from_operator(operator=operator_address) or transacting_power.account ) - participant = self.get_participant_from_provider(ritual_id, provider) + participant = self.get_participant(ritual_id, provider) participant.transcript = bytes(transcript) ritual.total_transcripts += 1 if ritual.total_transcripts == ritual.dkg_size: @@ -151,7 +151,7 @@ class MockCoordinatorAgent(MockContractAgent): self._get_staking_provider_from_operator(operator=operator_address) or transacting_power.account ) - participant = self.get_participant_from_provider(ritual_id, provider) + participant = self.get_participant(ritual_id, provider) participant.aggregated = True participant.decryption_request_static_key = bytes(participant_public_key) @@ -207,23 +207,26 @@ class MockCoordinatorAgent(MockContractAgent): def number_of_rituals(self) -> int: return len(self.rituals) - def get_ritual( - self, ritual_id: int, with_participants: bool = False - ) -> CoordinatorAgent.Ritual: + def get_ritual(self, ritual_id: int) -> CoordinatorAgent.Ritual: return self.rituals[ritual_id] - def get_participants(self, ritual_id: int) -> List[Participant]: - return self.rituals[ritual_id].participants + def is_participant(self, ritual_id: int, provider: ChecksumAddress) -> bool: + try: + self.get_participant(ritual_id, provider) + return True + except ValueError: + return False - def get_participant_from_provider( - self, ritual_id: int, provider: ChecksumAddress - ) -> Participant: + def get_participant(self, ritual_id: int, provider: ChecksumAddress) -> Participant: for p in self.rituals[ritual_id].participants: if p.provider == provider: return p raise ValueError(f"Provider {provider} not found for ritual #{ritual_id}") + def get_providers(self, ritual_id: int) -> List[ChecksumAddress]: + return [p.provider for p in self.rituals[ritual_id].participants] + def get_ritual_status(self, ritual_id: int) -> int: ritual = self.rituals[ritual_id] timestamp = int(ritual.init_timestamp) @@ -259,7 +262,7 @@ class MockCoordinatorAgent(MockContractAgent): f"No ritual id found for public key 0x{bytes(public_key).hex()}" ) - def get_ritual_public_key(self, ritual_id: int) -> DkgPublicKey: + def get_ritual_public_key(self, ritual_id: int) -> Optional[DkgPublicKey]: status = self.get_ritual_status(ritual_id=ritual_id) if status != self.Ritual.Status.ACTIVE and status != self.Ritual.Status.EXPIRED: # TODO should we raise here instead? @@ -277,8 +280,8 @@ class MockCoordinatorAgent(MockContractAgent): participant_keys = self._participant_keys_history[provider] for participant_key in reversed(participant_keys): if participant_key.lastRitualId <= ritual_id: - g2Point = participant_key.publicKey - return g2Point.to_public_key() + g2_point = participant_key.publicKey + return g2_point.to_public_key() raise ValueError( f"Public key not found for provider {provider} for ritual #{ritual_id}" diff --git a/tests/unit/test_ritualist.py b/tests/unit/test_ritualist.py index 1789de16a..771dec965 100644 --- a/tests/unit/test_ritualist.py +++ b/tests/unit/test_ritualist.py @@ -113,9 +113,7 @@ def test_perform_round_1( ) agent.get_ritual = lambda *args, **kwargs: ritual agent.get_participants = lambda *args, **kwargs: participants - agent.get_participant_from_provider = lambda ritual_id, provider: participants[ - provider - ] + agent.get_participant = lambda ritual_id, provider: participants[provider] # ensure no operation performed for non-application-state non_application_states = [ @@ -156,9 +154,7 @@ def test_perform_round_1( ursula.dkg_storage.store_transcript_receipt(ritual_id=0, txhash_or_receipt=None) # participant already posted transcript - participant = agent.get_participant_from_provider( - ritual_id=0, provider=ursula.checksum_address - ) + participant = agent.get_participant(ritual_id=0, provider=ursula.checksum_address) participant.transcript = bytes(random_transcript) # try submitting again @@ -213,9 +209,7 @@ def test_perform_round_2( agent.get_ritual = lambda *args, **kwargs: ritual agent.get_participants = lambda *args, **kwargs: participants - agent.get_participant_from_provider = lambda ritual_id, provider: participants[ - provider - ] + agent.get_participant = lambda ritual_id, provider: participants[provider] # ensure no operation performed for non-application-state non_application_states = [ @@ -253,9 +247,7 @@ def test_perform_round_2( ) # participant already posted aggregated transcript - participant = agent.get_participant_from_provider( - ritual_id=0, provider=ursula.checksum_address - ) + participant = agent.get_participant(ritual_id=0, provider=ursula.checksum_address) participant.aggregated = True # try submitting again