mirror of https://github.com/nucypher/nucypher.git
handle shared coordinator reads and dkg artifact preperation logic between ursula and decryptors.
parent
8cd51eebf5
commit
96cac14fd5
|
@ -291,24 +291,6 @@ class Operator(BaseActor):
|
||||||
|
|
||||||
return providers
|
return providers
|
||||||
|
|
||||||
def get_ritual(
|
|
||||||
self, ritual_id: int, download_transcripts: bool
|
|
||||||
) -> CoordinatorAgent.Ritual:
|
|
||||||
"""assembles Coordinator data using multiple RPC requests."""
|
|
||||||
ritual = self.coordinator_agent.get_ritual(ritual_id) # call 1
|
|
||||||
addresses = self.coordinator_agent.get_providers(ritual_id) # call 2
|
|
||||||
participants = []
|
|
||||||
for index, address in enumerate(addresses): # n calls
|
|
||||||
if download_transcripts:
|
|
||||||
participant = self.coordinator_agent.get_participant(ritual_id, address)
|
|
||||||
else:
|
|
||||||
participant = CoordinatorAgent.Ritual.Participant(
|
|
||||||
index=index, provider=address
|
|
||||||
)
|
|
||||||
participants.append(participant)
|
|
||||||
ritual.participants = participants
|
|
||||||
return ritual
|
|
||||||
|
|
||||||
def _resolve_validators(
|
def _resolve_validators(
|
||||||
self,
|
self,
|
||||||
ritual: CoordinatorAgent.Ritual,
|
ritual: CoordinatorAgent.Ritual,
|
||||||
|
@ -398,7 +380,7 @@ class Operator(BaseActor):
|
||||||
|
|
||||||
# validate the active ritual tracker state
|
# validate the active ritual tracker state
|
||||||
participant = self.coordinator_agent.get_participant(
|
participant = self.coordinator_agent.get_participant(
|
||||||
ritual_id=ritual_id, provider=self.checksum_address
|
ritual_id=ritual_id, provider=self.checksum_address, transcript=False
|
||||||
)
|
)
|
||||||
if participant.transcript:
|
if participant.transcript:
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
|
@ -425,9 +407,9 @@ class Operator(BaseActor):
|
||||||
f"performing round 1 of DKG ritual #{ritual_id} from blocktime {timestamp} with authority {authority}."
|
f"performing round 1 of DKG ritual #{ritual_id} from blocktime {timestamp} with authority {authority}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# gather the ritual metadata
|
# gather the ritual metadata and DKG artifacts
|
||||||
ritual = self.get_ritual(ritual_id, download_transcripts=False)
|
ritual = self.coordinator_agent.get_ritual(ritual_id)
|
||||||
validators = list(zip(*self._resolve_validators(ritual, ritual_id)))
|
validators = self._resolve_validators(ritual, ritual_id)
|
||||||
|
|
||||||
# generate a transcript
|
# generate a transcript
|
||||||
try:
|
try:
|
||||||
|
@ -473,7 +455,7 @@ class Operator(BaseActor):
|
||||||
|
|
||||||
# validate the active ritual tracker state
|
# validate the active ritual tracker state
|
||||||
participant = self.coordinator_agent.get_participant(
|
participant = self.coordinator_agent.get_participant(
|
||||||
ritual_id=ritual_id, provider=self.checksum_address
|
ritual_id=ritual_id, provider=self.checksum_address, transcript=False
|
||||||
)
|
)
|
||||||
if participant.aggregated:
|
if participant.aggregated:
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
|
@ -496,13 +478,18 @@ class Operator(BaseActor):
|
||||||
f"{self.transacting_power.account[:8]} performing round 2 of DKG ritual #{ritual_id} from blocktime {timestamp}"
|
f"{self.transacting_power.account[:8]} performing round 2 of DKG ritual #{ritual_id} from blocktime {timestamp}"
|
||||||
)
|
)
|
||||||
|
|
||||||
ritual = self.coordinator_agent.get_ritual(ritual_id, with_participants=True)
|
ritual = self.coordinator_agent.get_ritual(ritual_id, transcripts=True)
|
||||||
nodes, transcripts = self._resolve_validators(ritual, ritual_id)
|
missing = sum(1 for t in ritual.transcripts if not t)
|
||||||
if not all(transcripts):
|
if missing:
|
||||||
raise self.ActorError(
|
raise self.ActorError(
|
||||||
f"ritual #{ritual_id} is missing transcripts from {len([t for t in transcripts if not t])} nodes."
|
f"ritual #{ritual_id} is missing transcripts from {missing} nodes."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prepare the DKG artifacts
|
||||||
|
validators = self._resolve_validators(ritual, ritual_id)
|
||||||
|
transcripts = (Transcript.from_bytes(bytes(t)) for t in ritual.transcripts)
|
||||||
|
messages = list(zip(validators, transcripts))
|
||||||
|
|
||||||
# Aggregate the transcripts
|
# Aggregate the transcripts
|
||||||
try:
|
try:
|
||||||
result = self.ritual_power.aggregate_transcripts(
|
result = self.ritual_power.aggregate_transcripts(
|
||||||
|
@ -510,7 +497,7 @@ class Operator(BaseActor):
|
||||||
shares=ritual.shares,
|
shares=ritual.shares,
|
||||||
checksum_address=self.checksum_address,
|
checksum_address=self.checksum_address,
|
||||||
ritual_id=ritual_id,
|
ritual_id=ritual_id,
|
||||||
transcripts=transcripts
|
transcripts=messages,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.debug(f"Failed to aggregate transcripts for ritual #{ritual_id}: {str(e)}")
|
self.log.debug(f"Failed to aggregate transcripts for ritual #{ritual_id}: {str(e)}")
|
||||||
|
@ -558,16 +545,10 @@ class Operator(BaseActor):
|
||||||
ritual = self.coordinator_agent.get_ritual(ritual_id)
|
ritual = self.coordinator_agent.get_ritual(ritual_id)
|
||||||
if not self.coordinator_agent.is_ritual_active(ritual_id=ritual_id):
|
if not self.coordinator_agent.is_ritual_active(ritual_id=ritual_id):
|
||||||
raise self.ActorError(f"Ritual #{ritual_id} is not active.")
|
raise self.ActorError(f"Ritual #{ritual_id} is not active.")
|
||||||
|
validators = list(self._resolve_validators(ritual, ritual_id))
|
||||||
nodes, transcripts = list(zip(*self._resolve_validators(ritual, ritual_id)))
|
|
||||||
if not all(transcripts):
|
|
||||||
raise self.ActorError(f"Ritual #{ritual_id} is missing transcripts")
|
|
||||||
|
|
||||||
# TODO: consider the usage of local DKG artifact storage here #3052
|
|
||||||
# aggregated_transcript_bytes = self.dkg_storage.get_aggregated_transcript(ritual_id)
|
|
||||||
aggregated_transcript = AggregatedTranscript.from_bytes(bytes(ritual.aggregated_transcript))
|
aggregated_transcript = AggregatedTranscript.from_bytes(bytes(ritual.aggregated_transcript))
|
||||||
decryption_share = self.ritual_power.derive_decryption_share(
|
decryption_share = self.ritual_power.derive_decryption_share(
|
||||||
nodes=nodes,
|
nodes=validators,
|
||||||
threshold=ritual.threshold,
|
threshold=ritual.threshold,
|
||||||
shares=ritual.shares,
|
shares=ritual.shares,
|
||||||
checksum_address=self.checksum_address,
|
checksum_address=self.checksum_address,
|
||||||
|
@ -577,7 +558,6 @@ class Operator(BaseActor):
|
||||||
aad=aad,
|
aad=aad,
|
||||||
variant=variant
|
variant=variant
|
||||||
)
|
)
|
||||||
|
|
||||||
return decryption_share
|
return decryption_share
|
||||||
|
|
||||||
def decrypt_threshold_decryption_request(
|
def decrypt_threshold_decryption_request(
|
||||||
|
@ -602,7 +582,7 @@ class Operator(BaseActor):
|
||||||
# dkg_public_key = this_node.dkg_storage.get_public_key(decryption_request.ritual_id)
|
# dkg_public_key = this_node.dkg_storage.get_public_key(decryption_request.ritual_id)
|
||||||
|
|
||||||
# enforces that the node is part of the ritual
|
# enforces that the node is part of the ritual
|
||||||
participating = self.coordinator_agent.is_provider_participating(
|
participating = self.coordinator_agent.is_participant(
|
||||||
ritual_id=decryption_request.ritual_id, provider=self.checksum_address
|
ritual_id=decryption_request.ritual_id, provider=self.checksum_address
|
||||||
)
|
)
|
||||||
if not participating:
|
if not participating:
|
||||||
|
|
|
@ -723,11 +723,8 @@ class CoordinatorAgent(EthereumContractAgent):
|
||||||
return [p.provider for p in self.participants]
|
return [p.provider for p in self.participants]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def transcripts(self) -> List[Tuple[ChecksumAddress, bytes]]:
|
def transcripts(self) -> Iterable[bytes]:
|
||||||
transcripts = list()
|
return [p.transcript for p in self.participants]
|
||||||
for p in self.participants:
|
|
||||||
transcripts.append((p.provider, p.transcript))
|
|
||||||
return transcripts
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shares(self) -> int:
|
def shares(self) -> int:
|
||||||
|
@ -748,7 +745,7 @@ class CoordinatorAgent(EthereumContractAgent):
|
||||||
return self.contract.functions.timeout().call()
|
return self.contract.functions.timeout().call()
|
||||||
|
|
||||||
@contract_api(CONTRACT_CALL)
|
@contract_api(CONTRACT_CALL)
|
||||||
def get_ritual(self, ritual_id: int) -> Ritual:
|
def rituals(self, ritual_id: int) -> Ritual:
|
||||||
result = self.contract.functions.rituals(int(ritual_id)).call()
|
result = self.contract.functions.rituals(int(ritual_id)).call()
|
||||||
ritual = self.Ritual(
|
ritual = self.Ritual(
|
||||||
initiator=ChecksumAddress(result[0]),
|
initiator=ChecksumAddress(result[0]),
|
||||||
|
@ -764,11 +761,32 @@ class CoordinatorAgent(EthereumContractAgent):
|
||||||
aggregated_transcript=bytes(result[11]),
|
aggregated_transcript=bytes(result[11]),
|
||||||
participants=[], # solidity does not return sub-structs
|
participants=[], # solidity does not return sub-structs
|
||||||
)
|
)
|
||||||
|
|
||||||
# public key
|
# public key
|
||||||
ritual.public_key = self.Ritual.G1Point(result[10][0], result[10][1])
|
ritual.public_key = self.Ritual.G1Point(result[10][0], result[10][1])
|
||||||
return ritual
|
return ritual
|
||||||
|
|
||||||
|
def get_ritual(
|
||||||
|
self,
|
||||||
|
ritual_id: int,
|
||||||
|
transcripts: bool = False,
|
||||||
|
participants: bool = True,
|
||||||
|
) -> Ritual:
|
||||||
|
"""assembles Coordinator data using multiple RPC requests."""
|
||||||
|
if not participants and transcripts:
|
||||||
|
raise ValueError("Cannot get transcripts without participants")
|
||||||
|
|
||||||
|
ritual = self.rituals(ritual_id) # call 1
|
||||||
|
if not participants:
|
||||||
|
return ritual
|
||||||
|
|
||||||
|
participants = []
|
||||||
|
addresses = self.get_providers(ritual_id) # call 2
|
||||||
|
for index, address in enumerate(addresses): # n calls
|
||||||
|
participant = self.get_participant(ritual_id, address, transcripts)
|
||||||
|
participants.append(participant)
|
||||||
|
ritual.participants = participants
|
||||||
|
return ritual
|
||||||
|
|
||||||
@contract_api(CONTRACT_CALL)
|
@contract_api(CONTRACT_CALL)
|
||||||
def get_ritual_status(self, ritual_id: int) -> int:
|
def get_ritual_status(self, ritual_id: int) -> int:
|
||||||
result = self.contract.functions.getRitualState(ritual_id).call()
|
result = self.contract.functions.getRitualState(ritual_id).call()
|
||||||
|
@ -791,9 +809,11 @@ class CoordinatorAgent(EthereumContractAgent):
|
||||||
|
|
||||||
@contract_api(CONTRACT_CALL)
|
@contract_api(CONTRACT_CALL)
|
||||||
def get_participant(
|
def get_participant(
|
||||||
self, ritual_id: int, provider: ChecksumAddress
|
self, ritual_id: int, provider: ChecksumAddress, transcript: bool
|
||||||
) -> Ritual.Participant:
|
) -> Ritual.Participant:
|
||||||
data, index = self.contract.functions.getParticipant(ritual_id, provider).call()
|
data, index = self.contract.functions.getParticipant(
|
||||||
|
ritual_id, provider, transcript
|
||||||
|
).call()
|
||||||
participant = self.Ritual.Participant(
|
participant = self.Ritual.Participant(
|
||||||
index=index,
|
index=index,
|
||||||
provider=ChecksumAddress(data[0]),
|
provider=ChecksumAddress(data[0]),
|
||||||
|
@ -805,7 +825,7 @@ class CoordinatorAgent(EthereumContractAgent):
|
||||||
|
|
||||||
@contract_api(CONTRACT_CALL)
|
@contract_api(CONTRACT_CALL)
|
||||||
def get_provider_public_key(
|
def get_provider_public_key(
|
||||||
self, provider: ChecksumAddress, ritual_id: int
|
self, provider: ChecksumAddress, ritual_id: int
|
||||||
) -> FerveoPublicKey:
|
) -> FerveoPublicKey:
|
||||||
result = self.contract.functions.getProviderPublicKey(
|
result = self.contract.functions.getProviderPublicKey(
|
||||||
provider, ritual_id
|
provider, ritual_id
|
||||||
|
|
|
@ -246,7 +246,7 @@ class ActiveRitualTracker:
|
||||||
Returns node's participant information for the provided
|
Returns node's participant information for the provided
|
||||||
ritual id; None if node is not participating in the ritual
|
ritual id; None if node is not participating in the ritual
|
||||||
"""
|
"""
|
||||||
participants = self.coordinator_agent.get_participants(ritual_id=ritual_id)
|
participants = self.coordinator_agent.get_ritual(ritual_id=ritual_id)
|
||||||
for p in participants:
|
for p in participants:
|
||||||
if p.provider == self.operator.checksum_address:
|
if p.provider == self.operator.checksum_address:
|
||||||
return p
|
return p
|
||||||
|
|
|
@ -694,8 +694,9 @@ class Bob(Character):
|
||||||
)
|
)
|
||||||
return ritual_id
|
return ritual_id
|
||||||
|
|
||||||
def get_ritual_from_id(self, ritual_id) -> CoordinatorAgent.Ritual:
|
def get_ritual(self, ritual_id) -> CoordinatorAgent.Ritual:
|
||||||
ritual = self._get_coordinator_agent().get_ritual(ritual_id)
|
agent = self._get_coordinator_agent()
|
||||||
|
ritual = agent.get_ritual(ritual_id)
|
||||||
return ritual
|
return ritual
|
||||||
|
|
||||||
def threshold_decrypt(
|
def threshold_decrypt(
|
||||||
|
@ -708,7 +709,7 @@ class Bob(Character):
|
||||||
ritual_id = self.get_ritual_id_from_public_key(
|
ritual_id = self.get_ritual_id_from_public_key(
|
||||||
public_key=threshold_message_kit.acp.public_key
|
public_key=threshold_message_kit.acp.public_key
|
||||||
)
|
)
|
||||||
ritual = self.get_ritual_from_id(ritual_id=ritual_id)
|
ritual = self.get_ritual(ritual_id=ritual_id)
|
||||||
|
|
||||||
if ursulas:
|
if ursulas:
|
||||||
for ursula in ursulas:
|
for ursula in ursulas:
|
||||||
|
@ -726,10 +727,9 @@ class Bob(Character):
|
||||||
variant=variant,
|
variant=variant,
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
participant_public_keys = ritual.participant_public_keys
|
|
||||||
decryption_shares = self._get_decryption_shares(
|
decryption_shares = self._get_decryption_shares(
|
||||||
decryption_request=decryption_request,
|
decryption_request=decryption_request,
|
||||||
participant_public_keys=participant_public_keys,
|
participant_public_keys=ritual.participant_public_keys,
|
||||||
threshold=ritual.threshold,
|
threshold=ritual.threshold,
|
||||||
timeout=decryption_timeout,
|
timeout=decryption_timeout,
|
||||||
)
|
)
|
||||||
|
|
|
@ -166,7 +166,9 @@ def test_ursula_ritualist(
|
||||||
assert (
|
assert (
|
||||||
len(
|
len(
|
||||||
coordinator_agent.get_participant(
|
coordinator_agent.get_participant(
|
||||||
ritual_id=RITUAL_ID, provider=ursula.checksum_address
|
ritual_id=RITUAL_ID,
|
||||||
|
provider=ursula.checksum_address,
|
||||||
|
transcript=True,
|
||||||
).transcript
|
).transcript
|
||||||
)
|
)
|
||||||
> 0
|
> 0
|
||||||
|
|
|
@ -99,8 +99,8 @@ def test_initiate_ritual(
|
||||||
ritual = agent.get_ritual(ritual_id)
|
ritual = agent.get_ritual(ritual_id)
|
||||||
assert ritual.authority == authority
|
assert ritual.authority == authority
|
||||||
|
|
||||||
participants = agent.get_participants(ritual_id)
|
ritual = agent.get_ritual(ritual_id)
|
||||||
assert [p.provider for p in participants] == cohort
|
assert [p.provider for p in ritual.participants] == cohort
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
agent.get_ritual_status(ritual_id=ritual_id)
|
agent.get_ritual_status(ritual_id=ritual_id)
|
||||||
|
@ -128,8 +128,8 @@ def test_post_transcript(agent, transcripts, transacting_powers):
|
||||||
assert event["args"]["ritualId"] == ritual_id
|
assert event["args"]["ritualId"] == ritual_id
|
||||||
assert event["args"]["transcriptDigest"] == keccak(transcripts[i])
|
assert event["args"]["transcriptDigest"] == keccak(transcripts[i])
|
||||||
|
|
||||||
participants = agent.get_participants(ritual_id)
|
ritual = agent.get_ritual(ritual_id, transcripts=True)
|
||||||
assert [p.transcript for p in participants] == transcripts
|
assert [p.transcript for p in ritual.participants] == transcripts
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
agent.get_ritual_status(ritual_id=ritual_id)
|
agent.get_ritual_status(ritual_id=ritual_id)
|
||||||
|
@ -167,7 +167,7 @@ def test_post_aggregation(
|
||||||
bytes(aggregated_transcript)
|
bytes(aggregated_transcript)
|
||||||
)
|
)
|
||||||
|
|
||||||
participants = agent.get_participants(ritual_id)
|
participants = agent.get_ritual(ritual_id).participants
|
||||||
for p in participants:
|
for p in participants:
|
||||||
assert p.aggregated
|
assert p.aggregated
|
||||||
assert p.decryption_request_static_key == bytes(
|
assert p.decryption_request_static_key == bytes(
|
||||||
|
|
|
@ -243,7 +243,6 @@ def test_first_scan_start_block_calc_is_not_perfect_go_back_more_blocks(ritualis
|
||||||
|
|
||||||
|
|
||||||
def test_get_ritual_participant_info(ritualist, get_random_checksum_address):
|
def test_get_ritual_participant_info(ritualist, get_random_checksum_address):
|
||||||
mocked_agent = ritualist.coordinator_agent
|
|
||||||
active_ritual_tracker = ActiveRitualTracker(operator=ritualist)
|
active_ritual_tracker = ActiveRitualTracker(operator=ritualist)
|
||||||
|
|
||||||
participants = []
|
participants = []
|
||||||
|
@ -253,7 +252,6 @@ def test_get_ritual_participant_info(ritualist, get_random_checksum_address):
|
||||||
index=i, provider=get_random_checksum_address()
|
index=i, provider=get_random_checksum_address()
|
||||||
)
|
)
|
||||||
participants.append(participant)
|
participants.append(participant)
|
||||||
mocked_agent.get_participants.return_value = participants
|
|
||||||
|
|
||||||
# operator not in participants list
|
# operator not in participants list
|
||||||
participant_info = active_ritual_tracker._get_ritual_participant_info(ritual_id=0)
|
participant_info = active_ritual_tracker._get_ritual_participant_info(ritual_id=0)
|
||||||
|
@ -274,7 +272,6 @@ def test_get_ritual_participant_info(ritualist, get_random_checksum_address):
|
||||||
def test_get_participation_state_values_from_contract(
|
def test_get_participation_state_values_from_contract(
|
||||||
ritualist, get_random_checksum_address
|
ritualist, get_random_checksum_address
|
||||||
):
|
):
|
||||||
mocked_agent = ritualist.coordinator_agent
|
|
||||||
active_ritual_tracker = ActiveRitualTracker(operator=ritualist)
|
active_ritual_tracker = ActiveRitualTracker(operator=ritualist)
|
||||||
|
|
||||||
participants = []
|
participants = []
|
||||||
|
@ -285,8 +282,6 @@ def test_get_participation_state_values_from_contract(
|
||||||
)
|
)
|
||||||
participants.append(participant)
|
participants.append(participant)
|
||||||
|
|
||||||
mocked_agent.get_participants.return_value = participants
|
|
||||||
|
|
||||||
# not participating so everything should be False
|
# not participating so everything should be False
|
||||||
(
|
(
|
||||||
participating,
|
participating,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import time
|
import time
|
||||||
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, NamedTuple, Optional
|
from typing import Dict, List, NamedTuple, Optional
|
||||||
|
|
||||||
|
@ -121,7 +122,7 @@ class MockCoordinatorAgent(MockContractAgent):
|
||||||
self._get_staking_provider_from_operator(operator=operator_address)
|
self._get_staking_provider_from_operator(operator=operator_address)
|
||||||
or transacting_power.account
|
or transacting_power.account
|
||||||
)
|
)
|
||||||
participant = self.get_participant(ritual_id, provider)
|
participant = self.get_participant(ritual_id, provider, False)
|
||||||
participant.transcript = bytes(transcript)
|
participant.transcript = bytes(transcript)
|
||||||
ritual.total_transcripts += 1
|
ritual.total_transcripts += 1
|
||||||
if ritual.total_transcripts == ritual.dkg_size:
|
if ritual.total_transcripts == ritual.dkg_size:
|
||||||
|
@ -151,7 +152,7 @@ class MockCoordinatorAgent(MockContractAgent):
|
||||||
self._get_staking_provider_from_operator(operator=operator_address)
|
self._get_staking_provider_from_operator(operator=operator_address)
|
||||||
or transacting_power.account
|
or transacting_power.account
|
||||||
)
|
)
|
||||||
participant = self.get_participant(ritual_id, provider)
|
participant = self.get_participant(ritual_id, provider, True)
|
||||||
participant.aggregated = True
|
participant.aggregated = True
|
||||||
participant.decryption_request_static_key = bytes(participant_public_key)
|
participant.decryption_request_static_key = bytes(participant_public_key)
|
||||||
|
|
||||||
|
@ -207,19 +208,27 @@ class MockCoordinatorAgent(MockContractAgent):
|
||||||
def number_of_rituals(self) -> int:
|
def number_of_rituals(self) -> int:
|
||||||
return len(self.rituals)
|
return len(self.rituals)
|
||||||
|
|
||||||
def get_ritual(self, ritual_id: int) -> CoordinatorAgent.Ritual:
|
def get_ritual(
|
||||||
return self.rituals[ritual_id]
|
self, ritual_id: int, transcripts: bool = False, participants: bool = True
|
||||||
|
) -> CoordinatorAgent.Ritual:
|
||||||
|
ritual = deepcopy(self.rituals[ritual_id])
|
||||||
|
return ritual
|
||||||
|
|
||||||
def is_participant(self, ritual_id: int, provider: ChecksumAddress) -> bool:
|
def is_participant(self, ritual_id: int, provider: ChecksumAddress) -> bool:
|
||||||
try:
|
try:
|
||||||
self.get_participant(ritual_id, provider)
|
self.get_participant(ritual_id, provider, False)
|
||||||
return True
|
return True
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_participant(self, ritual_id: int, provider: ChecksumAddress) -> Participant:
|
def get_participant(
|
||||||
|
self, ritual_id: int, provider: ChecksumAddress, transcript: bool
|
||||||
|
) -> Participant:
|
||||||
for p in self.rituals[ritual_id].participants:
|
for p in self.rituals[ritual_id].participants:
|
||||||
if p.provider == provider:
|
if p.provider == provider:
|
||||||
|
# if not transcripts:
|
||||||
|
# p = deepcopy(p)
|
||||||
|
# p.transcript = b""
|
||||||
return p
|
return p
|
||||||
|
|
||||||
raise ValueError(f"Provider {provider} not found for ritual #{ritual_id}")
|
raise ValueError(f"Provider {provider} not found for ritual #{ritual_id}")
|
||||||
|
|
|
@ -75,7 +75,6 @@ def test_initiate_ritual(
|
||||||
participants=participants,
|
participants=participants,
|
||||||
)
|
)
|
||||||
agent.get_ritual = lambda *args, **kwargs: ritual
|
agent.get_ritual = lambda *args, **kwargs: ritual
|
||||||
agent.get_participants = lambda *args, **kwargs: participants
|
|
||||||
|
|
||||||
assert receipt["transactionHash"]
|
assert receipt["transactionHash"]
|
||||||
number_of_rituals = agent.number_of_rituals()
|
number_of_rituals = agent.number_of_rituals()
|
||||||
|
@ -112,8 +111,9 @@ def test_perform_round_1(
|
||||||
participants=list(participants.values()),
|
participants=list(participants.values()),
|
||||||
)
|
)
|
||||||
agent.get_ritual = lambda *args, **kwargs: ritual
|
agent.get_ritual = lambda *args, **kwargs: ritual
|
||||||
agent.get_participants = lambda *args, **kwargs: participants
|
agent.get_participant = lambda ritual_id, provider, transcripts: participants[
|
||||||
agent.get_participant = lambda ritual_id, provider: participants[provider]
|
provider
|
||||||
|
]
|
||||||
|
|
||||||
# ensure no operation performed for non-application-state
|
# ensure no operation performed for non-application-state
|
||||||
non_application_states = [
|
non_application_states = [
|
||||||
|
@ -154,7 +154,9 @@ def test_perform_round_1(
|
||||||
ursula.dkg_storage.store_transcript_receipt(ritual_id=0, txhash_or_receipt=None)
|
ursula.dkg_storage.store_transcript_receipt(ritual_id=0, txhash_or_receipt=None)
|
||||||
|
|
||||||
# participant already posted transcript
|
# participant already posted transcript
|
||||||
participant = agent.get_participant(ritual_id=0, provider=ursula.checksum_address)
|
participant = agent.get_participant(
|
||||||
|
ritual_id=0, provider=ursula.checksum_address, transcript=False
|
||||||
|
)
|
||||||
participant.transcript = bytes(random_transcript)
|
participant.transcript = bytes(random_transcript)
|
||||||
|
|
||||||
# try submitting again
|
# try submitting again
|
||||||
|
@ -208,7 +210,6 @@ def test_perform_round_2(
|
||||||
)
|
)
|
||||||
|
|
||||||
agent.get_ritual = lambda *args, **kwargs: ritual
|
agent.get_ritual = lambda *args, **kwargs: ritual
|
||||||
agent.get_participants = lambda *args, **kwargs: participants
|
|
||||||
agent.get_participant = lambda ritual_id, provider: participants[provider]
|
agent.get_participant = lambda ritual_id, provider: participants[provider]
|
||||||
|
|
||||||
# ensure no operation performed for non-application-state
|
# ensure no operation performed for non-application-state
|
||||||
|
|
Loading…
Reference in New Issue