mirror of https://github.com/nucypher/nucypher.git
updates mocks to sync with coordinator agent updates
parent
9f8661c0bb
commit
c1495b66a7
|
@ -1,6 +1,6 @@
|
|||
import time
|
||||
from enum import Enum
|
||||
from typing import Dict, List, NamedTuple
|
||||
from typing import Dict, List, NamedTuple, Optional
|
||||
|
||||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import keccak
|
||||
|
@ -121,7 +121,7 @@ class MockCoordinatorAgent(MockContractAgent):
|
|||
self._get_staking_provider_from_operator(operator=operator_address)
|
||||
or transacting_power.account
|
||||
)
|
||||
participant = self.get_participant_from_provider(ritual_id, provider)
|
||||
participant = self.get_participant(ritual_id, provider)
|
||||
participant.transcript = bytes(transcript)
|
||||
ritual.total_transcripts += 1
|
||||
if ritual.total_transcripts == ritual.dkg_size:
|
||||
|
@ -151,7 +151,7 @@ class MockCoordinatorAgent(MockContractAgent):
|
|||
self._get_staking_provider_from_operator(operator=operator_address)
|
||||
or transacting_power.account
|
||||
)
|
||||
participant = self.get_participant_from_provider(ritual_id, provider)
|
||||
participant = self.get_participant(ritual_id, provider)
|
||||
participant.aggregated = True
|
||||
participant.decryption_request_static_key = bytes(participant_public_key)
|
||||
|
||||
|
@ -207,23 +207,26 @@ class MockCoordinatorAgent(MockContractAgent):
|
|||
def number_of_rituals(self) -> int:
|
||||
return len(self.rituals)
|
||||
|
||||
def get_ritual(
|
||||
self, ritual_id: int, with_participants: bool = False
|
||||
) -> CoordinatorAgent.Ritual:
|
||||
def get_ritual(self, ritual_id: int) -> CoordinatorAgent.Ritual:
|
||||
return self.rituals[ritual_id]
|
||||
|
||||
def get_participants(self, ritual_id: int) -> List[Participant]:
|
||||
return self.rituals[ritual_id].participants
|
||||
def is_participant(self, ritual_id: int, provider: ChecksumAddress) -> bool:
|
||||
try:
|
||||
self.get_participant(ritual_id, provider)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def get_participant_from_provider(
|
||||
self, ritual_id: int, provider: ChecksumAddress
|
||||
) -> Participant:
|
||||
def get_participant(self, ritual_id: int, provider: ChecksumAddress) -> Participant:
|
||||
for p in self.rituals[ritual_id].participants:
|
||||
if p.provider == provider:
|
||||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider} not found for ritual #{ritual_id}")
|
||||
|
||||
def get_providers(self, ritual_id: int) -> List[ChecksumAddress]:
|
||||
return [p.provider for p in self.rituals[ritual_id].participants]
|
||||
|
||||
def get_ritual_status(self, ritual_id: int) -> int:
|
||||
ritual = self.rituals[ritual_id]
|
||||
timestamp = int(ritual.init_timestamp)
|
||||
|
@ -259,7 +262,7 @@ class MockCoordinatorAgent(MockContractAgent):
|
|||
f"No ritual id found for public key 0x{bytes(public_key).hex()}"
|
||||
)
|
||||
|
||||
def get_ritual_public_key(self, ritual_id: int) -> DkgPublicKey:
|
||||
def get_ritual_public_key(self, ritual_id: int) -> Optional[DkgPublicKey]:
|
||||
status = self.get_ritual_status(ritual_id=ritual_id)
|
||||
if status != self.Ritual.Status.ACTIVE and status != self.Ritual.Status.EXPIRED:
|
||||
# TODO should we raise here instead?
|
||||
|
@ -277,8 +280,8 @@ class MockCoordinatorAgent(MockContractAgent):
|
|||
participant_keys = self._participant_keys_history[provider]
|
||||
for participant_key in reversed(participant_keys):
|
||||
if participant_key.lastRitualId <= ritual_id:
|
||||
g2Point = participant_key.publicKey
|
||||
return g2Point.to_public_key()
|
||||
g2_point = participant_key.publicKey
|
||||
return g2_point.to_public_key()
|
||||
|
||||
raise ValueError(
|
||||
f"Public key not found for provider {provider} for ritual #{ritual_id}"
|
||||
|
|
|
@ -113,9 +113,7 @@ def test_perform_round_1(
|
|||
)
|
||||
agent.get_ritual = lambda *args, **kwargs: ritual
|
||||
agent.get_participants = lambda *args, **kwargs: participants
|
||||
agent.get_participant_from_provider = lambda ritual_id, provider: participants[
|
||||
provider
|
||||
]
|
||||
agent.get_participant = lambda ritual_id, provider: participants[provider]
|
||||
|
||||
# ensure no operation performed for non-application-state
|
||||
non_application_states = [
|
||||
|
@ -156,9 +154,7 @@ def test_perform_round_1(
|
|||
ursula.dkg_storage.store_transcript_receipt(ritual_id=0, txhash_or_receipt=None)
|
||||
|
||||
# participant already posted transcript
|
||||
participant = agent.get_participant_from_provider(
|
||||
ritual_id=0, provider=ursula.checksum_address
|
||||
)
|
||||
participant = agent.get_participant(ritual_id=0, provider=ursula.checksum_address)
|
||||
participant.transcript = bytes(random_transcript)
|
||||
|
||||
# try submitting again
|
||||
|
@ -213,9 +209,7 @@ def test_perform_round_2(
|
|||
|
||||
agent.get_ritual = lambda *args, **kwargs: ritual
|
||||
agent.get_participants = lambda *args, **kwargs: participants
|
||||
agent.get_participant_from_provider = lambda ritual_id, provider: participants[
|
||||
provider
|
||||
]
|
||||
agent.get_participant = lambda ritual_id, provider: participants[provider]
|
||||
|
||||
# ensure no operation performed for non-application-state
|
||||
non_application_states = [
|
||||
|
@ -253,9 +247,7 @@ def test_perform_round_2(
|
|||
)
|
||||
|
||||
# participant already posted aggregated transcript
|
||||
participant = agent.get_participant_from_provider(
|
||||
ritual_id=0, provider=ursula.checksum_address
|
||||
)
|
||||
participant = agent.get_participant(ritual_id=0, provider=ursula.checksum_address)
|
||||
participant.aggregated = True
|
||||
|
||||
# try submitting again
|
||||
|
|
Loading…
Reference in New Issue