nucypher/tests/unit/test_coordinator.py

137 lines
4.4 KiB
Python
Raw Normal View History

from collections import OrderedDict
from unittest.mock import Mock
2023-02-01 16:53:02 +00:00
import pytest
from eth_account import Account
from nucypher_core import RequestSecretKey
2023-02-01 16:53:02 +00:00
from tests.mock.coordinator import MockCoordinatorAgent
from tests.mock.interfaces import MockBlockchain
2023-02-01 16:53:02 +00:00
# Number of mock node accounts created for the ritual; the tests also expect
# the initiated ritual to have exactly this many participants.
DKG_SIZE = 4
2023-02-01 16:53:02 +00:00
@pytest.fixture(scope='module')
def nodes_transacting_powers():
    """Build an ordered ``address -> mock transacting power`` mapping.

    One freshly created account is generated per DKG node (``DKG_SIZE`` in
    total). Each mock exposes the account address through its ``account``
    attribute. The fixture is module-scoped, so every test in this module
    receives the same mapping.
    """
    powers_by_address = OrderedDict()
    for address in (Account.create().address for _ in range(DKG_SIZE)):
        power = Mock()
        power.account = address
        powers_by_address[address] = power
    return powers_by_address
@pytest.fixture(scope='module')
def coordinator():
    """Provide one mock coordinator agent, backed by a mock blockchain.

    The fixture is module-scoped, so the same agent instance (and whatever
    state it accumulates) is handed to every test in this module.
    """
    blockchain = MockBlockchain()
    agent = MockCoordinatorAgent(blockchain=blockchain)
    return agent
2023-02-01 16:53:02 +00:00
def test_mock_coordinator_creation(coordinator):
    """A newly created mock coordinator holds no rituals."""
    ritual_count = len(coordinator.rituals)
    assert ritual_count == 0
def test_mock_coordinator_initiation(mocker, nodes_transacting_powers, coordinator, random_address):
    """Initiating a ritual registers it and records one START_RITUAL event."""
    assert len(coordinator.rituals) == 0

    # The initiator is a mock transacting power bound to a random address.
    initiator_power = mocker.Mock()
    initiator_power.account = random_address
    provider_addresses = list(nodes_transacting_powers.keys())
    coordinator.initiate_ritual(
        providers=provider_addresses,
        transacting_power=initiator_power,
    )

    assert len(coordinator.rituals) == 1
    assert coordinator.number_of_rituals() == 1

    new_ritual = coordinator.rituals[0]
    assert len(new_ritual.participants) == DKG_SIZE
    # No participant has posted a transcript yet.
    for participant in new_ritual.participants:
        assert participant.transcript == bytes()

    assert len(coordinator.EVENTS) == 1
    recorded_events = list(coordinator.EVENTS.items())
    # Each entry is keyed by timestamp and holds a (type, data) pair.
    _timestamp, (event_type, event_payload) = recorded_events[0]
    assert event_type == MockCoordinatorAgent.Events.START_RITUAL
    assert event_payload["ritual_id"] == 0
    assert event_payload["initiator"] == initiator_power.account
    assert set(event_payload["participants"]) == nodes_transacting_powers.keys()
2023-02-01 16:53:02 +00:00
def test_mock_coordinator_round_1(
    nodes_transacting_powers, coordinator, random_transcript
):
    """Every node posts a transcript; the last one triggers the aggregation round.

    Relies on ritual 0 already existing on the shared module-scoped coordinator.
    """
    ritual = coordinator.rituals[0]
    status_before = coordinator.get_ritual_status(0)
    assert status_before == MockCoordinatorAgent.RitualStatus.AWAITING_TRANSCRIPTS

    # Nothing has been posted yet.
    for participant in ritual.participants:
        assert participant.transcript == bytes()

    final_position = len(nodes_transacting_powers) - 1
    for position, power in enumerate(nodes_transacting_powers.values()):
        coordinator.post_transcript(
            ritual_id=0,
            transcript=random_transcript,
            transacting_power=power,
        )

        # The participant at the same position now holds the posted transcript.
        posted = ritual.participants[position]
        assert posted.transcript == bytes(random_transcript)

        if position == final_position:
            assert len(coordinator.EVENTS) == 2

    # The second recorded event announces the aggregation round.
    _timestamp, (event_type, event_payload) = list(coordinator.EVENTS.items())[1]
    assert event_type == MockCoordinatorAgent.Events.START_AGGREGATION_ROUND
    assert event_payload["ritual_id"] == 0
2023-02-01 16:53:02 +00:00
def test_mock_coordinator_round_2(
    nodes_transacting_powers,
    coordinator,
    aggregated_transcript,
    dkg_public_key,
    random_transcript,
):
    """Every node posts its aggregation, after which the ritual is finalized.

    Relies on ritual 0 having completed the transcript round on the shared
    module-scoped coordinator.
    """
    ritual = coordinator.rituals[0]
    status_before = coordinator.get_ritual_status(0)
    assert status_before == MockCoordinatorAgent.RitualStatus.AWAITING_AGGREGATIONS

    # Every participant still holds the transcript posted in the previous round.
    for participant in ritual.participants:
        assert participant.transcript == bytes(random_transcript)

    submitted_keys = []
    final_position = len(nodes_transacting_powers) - 1
    for position, power in enumerate(nodes_transacting_powers.values()):
        request_key = RequestSecretKey.random().public_key()
        coordinator.post_aggregation(
            ritual_id=0,
            aggregated_transcript=aggregated_transcript,
            public_key=dkg_public_key,
            participant_public_key=request_key,
            transacting_power=power,
        )
        submitted_keys.append(request_key)
        if position == final_position:
            assert len(coordinator.EVENTS) == 2

    assert ritual.aggregated_transcript == bytes(aggregated_transcript)
    assert bytes(ritual.public_key) == bytes(dkg_public_key)

    for position, participant in enumerate(ritual.participants):
        # Individual transcripts are unchanged by aggregation.
        assert participant.transcript == bytes(random_transcript)
        assert participant.transcript != bytes(aggregated_transcript)
        assert participant.requestEncryptingKey == bytes(submitted_keys[position])

    assert len(coordinator.EVENTS) == 2  # no additional event emitted here?
    status_after = coordinator.get_ritual_status(0)
    assert status_after == MockCoordinatorAgent.RitualStatus.FINALIZED