Obtain request encrypting keys from Ritual object; don't convert object to bytes (transcript, agg transcript) before passing to Coordinator agent.

Update tests.
pull/3123/head
derekpierre 2023-05-19 13:35:05 -04:00
parent 0f9e044075
commit b9f0f0838a
11 changed files with 126 additions and 73 deletions

View File

@ -374,7 +374,7 @@ class Ritualist(BaseActor):
# look up the node index for this node on the blockchain
receipt = self.coordinator_agent.post_transcript(
ritual_id=ritual_id,
transcript=bytes(transcript),
transcript=transcript,
transacting_power=self.transacting_power
)
return receipt
@ -389,10 +389,10 @@ class Ritualist(BaseActor):
# look up the node index for this node on the blockchain
request_encrypting_key = self.threshold_request_power.get_pubkey_from_ritual_id(
ritual_id
).to_compressed_bytes()
)
receipt = self.coordinator_agent.post_aggregation(
ritual_id=ritual_id,
aggregated_transcript=bytes(aggregated_transcript),
aggregated_transcript=aggregated_transcript,
public_key=public_key,
request_encrypting_key=request_encrypting_key,
transacting_power=self.transacting_power

View File

@ -10,7 +10,8 @@ from constant_sorrow.constants import CONTRACT_ATTRIBUTE # type: ignore
from constant_sorrow.constants import CONTRACT_CALL, TRANSACTION
from eth_typing.evm import ChecksumAddress
from eth_utils.address import to_checksum_address
from ferveo_py.ferveo_py import DkgPublicKey
from ferveo_py.ferveo_py import AggregatedTranscript, DkgPublicKey, Transcript
from nucypher_core.umbral import PublicKey
from web3.contract.contract import Contract, ContractFunction
from web3.types import Timestamp, TxParams, TxReceipt, Wei
@ -611,6 +612,16 @@ class CoordinatorAgent(EthereumContractAgent):
def shares(self) -> int:
return len(self.providers)
@property
def request_encrypting_keys(self):
request_encrypting_keys = {}
for p in self.participants:
request_encrypting_keys[p.provider] = PublicKey.from_compressed_bytes(
p.requestEncryptingKey
)
return request_encrypting_keys
@contract_api(CONTRACT_CALL)
def get_timeout(self) -> int:
return self.contract.functions.timeout().call()
@ -693,12 +704,11 @@ class CoordinatorAgent(EthereumContractAgent):
def post_transcript(
self,
ritual_id: int,
transcript: bytes,
transcript: Transcript,
transacting_power: TransactingPower,
) -> TxReceipt:
contract_function: ContractFunction = self.contract.functions.postTranscript(
ritualId=ritual_id,
transcript=transcript
ritualId=ritual_id, transcript=bytes(transcript)
)
receipt = self.blockchain.send_transaction(contract_function=contract_function,
transacting_power=transacting_power)
@ -708,16 +718,16 @@ class CoordinatorAgent(EthereumContractAgent):
def post_aggregation(
self,
ritual_id: int,
aggregated_transcript: bytes,
aggregated_transcript: AggregatedTranscript,
public_key: DkgPublicKey,
request_encrypting_key: bytes,
request_encrypting_key: PublicKey,
transacting_power: TransactingPower,
) -> TxReceipt:
contract_function: ContractFunction = self.contract.functions.postAggregation(
ritualId=ritual_id,
aggregatedTranscript=aggregated_transcript,
aggregatedTranscript=bytes(aggregated_transcript),
publicKey=self.Ritual.G1Point.from_dkg_public_key(public_key),
requestEncryptingKey=request_encrypting_key,
requestEncryptingKey=request_encrypting_key.to_compressed_bytes(),
)
receipt = self.blockchain.send_transaction(
contract_function=contract_function,

View File

@ -1,6 +1,3 @@
from eth_tester import EthereumTester, PyEVMBackend
from eth_tester.backends.mock.main import MockBackend
from typing import Union

View File

@ -47,14 +47,12 @@ from nucypher_core import (
Conditions,
Context,
EncryptedKeyFrag,
EncryptedThresholdDecryptionResponse,
EncryptedTreasureMap,
MessageKit,
NodeMetadata,
NodeMetadataPayload,
ReencryptionResponse,
ThresholdDecryptionRequest,
ThresholdDecryptionResponse,
TreasureMap,
)
from nucypher_core.umbral import (
@ -699,14 +697,7 @@ class Bob(Character):
else ritual.shares
) # TODO: #3095 get this from the ritual / put it on-chain?
request_encrypting_keys = {}
participants = ritual.participants
for p in participants:
# TODO don't use Umbral in the long-run
request_encrypting_keys[p.provider] = PublicKey.from_compressed_bytes(
p.requestEncryptingKey
)
request_encrypting_keys = ritual.request_encrypting_keys
decryption_shares = self.gather_decryption_shares(
ritual_id=ritual_id,
cohort=ursulas,

View File

@ -2,6 +2,7 @@ import os
import pytest
from eth_utils import keccak
from nucypher_core.umbral import SecretKey
from nucypher.blockchain.eth.agents import (
ContractAgency,
@ -25,11 +26,6 @@ def transcripts():
return [os.urandom(32), os.urandom(32)]
@pytest.fixture(scope='module')
def aggregated_transcript():
return os.urandom(32)
@pytest.fixture(scope="module")
def cohort(testerchain, staking_providers):
deployer, cohort_provider_1, cohort_provider_2, *everybody_else = staking_providers
@ -126,18 +122,20 @@ def test_post_transcript(agent, transcripts, transacting_powers):
def test_post_aggregation(
agent, aggregated_transcript, dkg_public_key, transacting_powers
agent, aggregated_transcript, dkg_public_key, transacting_powers, cohort
):
ritual_id = agent.number_of_rituals() - 1
request_encrypting_keys = [os.urandom(32) for t in transacting_powers]
request_encrypting_keys = {}
for i, transacting_power in enumerate(transacting_powers):
request_encrypting_key = SecretKey.random().public_key()
receipt = agent.post_aggregation(
ritual_id=ritual_id,
aggregated_transcript=aggregated_transcript,
public_key=dkg_public_key,
request_encrypting_key=request_encrypting_keys[i],
request_encrypting_key=request_encrypting_key,
transacting_power=transacting_power,
)
request_encrypting_keys[cohort[i]] = request_encrypting_key
assert receipt["status"] == 1
post_aggregation_events = (
@ -147,13 +145,19 @@ def test_post_aggregation(
event = post_aggregation_events[0]
assert event["args"]["ritualId"] == ritual_id
assert event["args"]["aggregatedTranscriptDigest"] == keccak(
aggregated_transcript
bytes(aggregated_transcript)
)
participants = agent.get_participants(ritual_id)
for i, p in enumerate(participants):
for p in participants:
assert p.aggregated
assert p.requestEncryptingKey == request_encrypting_keys[i]
assert (
p.requestEncryptingKey
== request_encrypting_keys[p.provider].to_compressed_bytes()
)
ritual = agent.get_ritual(ritual_id)
assert ritual.request_encrypting_keys == request_encrypting_keys
assert agent.get_ritual_status(ritual_id=ritual_id) == agent.Ritual.Status.FINALIZED

View File

@ -156,6 +156,3 @@ RPC_SUCCESSFUL_RESPONSE = {
"id": 1,
"result": "Geth/v1.9.20-stable-979fc968/linux-amd64/go1.15"
}
FAKE_TRANSCRIPT = b'\x98\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\xae\xdb_-\xeaj\x9bz\xdd\xd6\x98\xf8\xf91A\xc1\x8f;\x13@\x89\xcb\xcf>\x86\xc4T\xfb\x0c\x1ety\x8b\xd8mSkk\xbb\xcaU\xe5]v}E\xfa\xbc\xae\xb6\xa1\xf4e\x19\x86\xf2L\xcaZj\x03]h:\xbfP\x03Q\x8c\x95e\xe0c\xaa\xc2\xb4\xbby}\xecW%\xdet\xc8\xfc\xe7ky\xe5\xf6\xe9\xf5\x05\xe5\xdf\x81\x9bx\x18\xa4\x15\x85\xdeA9\x9f\x99\xceQ\xb0\xd0&\x9a\xa7\xaed&\x99\xdc\xa7\xfeLM\x01\x02\x87\xc8\x14$\x89"kA\x0b\x91\t\x1e\x1c/f\x00N,\x88\x01\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\xab\x0f\tFA\xdcB\xd4\xb3\x08\xd7IVkmw6za\xb6)\x13\x014]f.\xa1\xcd\xe27\xee\xc0\x95\xf6\xa4\x12\xa9\x19\x94\xed\x05\xffF\x81\xb2\xb2\xcb\x06\xaf-\xe4\xb5\x98\xbd\x81\x0f\xb8\xb7\xa1<\xf6/\xe5\xa4\x11\x83}\xfaH\x15\x80h\n\xe7\xc6\xc2\xb3\xd5{dH\xeb\x1e]v\xb4\x88v\x88\xb7N1\xff\x80\xd0\x88\x04.\x00\x82K\x1e\x96\xa0\xbd}X\xbb{?6\xeb\xe7\rg\x03\xeeG\x01\x10^\xee\x9cH\x94[\x9d8s\xa3\xb6\x8f\xfc\xf1\xdf\x01m\xf9\x08_N\xb5-\x16O\x89n\x95\xf3\x8b[\x1f&Yk?*\x07\x8fQ\x98\x85\xd5\xc1YL\xe0CB\xb2"!\x8d,\x90Q7\xca\x9c\x0e\xb2\x7f\xb0\xe1\xc8\xdd\xe7\xe1\xe4\x14\xb3\xa6\xb4\x8e\x8b\xed\xacM\xc3\x9d\xc4|U\x93k\x17\xac\x14\x86\x16\xd7\xebk\xbd{\xad}\x87\x13Y\x83\x9d\x88\x1e\x1b4\xa7r\xa6\x80\xbf\xf0\x15\x99\x11Q\xdb\xeb\xdf\x15ns\xc6\x85\xb3\x1d\xf5j\xc5\x87`=OD\x86\x86\x08\x8d\xb6\x0b\xec\x1d\x15\xc9\x93\x9a\xed\xa3\xe2\x96\xa4\xa2b\xa6\xa5h\xb0\xbb4\xb3\x0c\xa5\xdcu\x1f{\xb9\xaf\xd0W\xe1\xa3&\xa8\xb5\xea\xe5c\xfd\xc7?\xbdLg\xb3\xae\xb9\xb8*\xfc\xd5\xa6\xeeI\x15v\xdc\xa2`1VZ\xb5\x1c_`\x86\xbe{\xef\xae\t\xf2\xa9N\x00\x9a\xa1F\x84\xb2\xe3\xbc\xfa\xf7I\xee\xe8[~\x99;i\xfc%\xa8\x80\x80\x8e%\'\x9c+\x9c\xa9\x13R!\x80w\xc0\xda[\x84\xf6X\xfe\xc2\xe3\x0f\x94-\xbb`\x00\x00\x00\x00\x00\x00\x00\x93\xff\x1e\x1b\x15;e\xfe}\x83v K\xf9\r\xc9\xad\x9d\xddN\xcd\xcaWq\xfa\x8e\x98sn\x9b~t\x01 =p\xe5\xb1\x7f"!\xb4\xb9\xc9W\x90\x86\x80\x17\nm\xa0\x8dD\xb5\xaf\xfc\xa5\xf5%V]\xb9\x89a@\xe5\x0c@#%x\xecW\xed\xb0a\x98\x1a!C\x80B@{\xf0\xffJ{\xa3\xeayDP\'u'

View File

@ -6,15 +6,16 @@ import tempfile
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import Tuple
import maya
import pytest
from click.testing import CliRunner
from eth_account import Account
from eth_utils import to_checksum_address
from ferveo_py.ferveo_py import DkgPublicKey
from ferveo_py.ferveo_py import AggregatedTranscript, DkgPublicKey, DkgPublicParameters
from ferveo_py.ferveo_py import Keypair as FerveoKeyPair
from ferveo_py.ferveo_py import Validator
from ferveo_py.ferveo_py import Transcript, Validator
from twisted.internet.task import Clock
from web3 import Web3
@ -380,7 +381,7 @@ def log_in_and_out_of_test(request):
test_logger.info(f"Finalized {module_name}.py::{test_name}")
@pytest.fixture(scope='module')
@pytest.fixture(scope="session")
def get_random_checksum_address():
def _get_random_checksum_address():
canonical_address = os.urandom(20)
@ -699,8 +700,10 @@ def ursulas(testerchain, staking_providers, ursula_test_config):
_ursulas.clear()
@pytest.fixture(scope="module")
def dkg_public_key(get_random_checksum_address) -> DkgPublicKey:
@pytest.fixture(scope="session")
def dkg_public_key_data(
get_random_checksum_address,
) -> Tuple[AggregatedTranscript, DkgPublicKey, DkgPublicParameters]:
ritual_id = 0
num_shares = 4
threshold = 3
@ -726,7 +729,7 @@ def dkg_public_key(get_random_checksum_address) -> DkgPublicKey:
)
transcripts.append(transcript)
_, public_key, _ = dkg.aggregate_transcripts(
aggregate_transcript, public_key, params = dkg.aggregate_transcripts(
ritual_id=ritual_id,
me=validators[0],
shares=num_shares,
@ -734,4 +737,16 @@ def dkg_public_key(get_random_checksum_address) -> DkgPublicKey:
transcripts=list(zip(validators, transcripts)),
)
return public_key
return aggregate_transcript, public_key, params
@pytest.fixture(scope="session")
def dkg_public_key(dkg_public_key_data) -> DkgPublicKey:
_, dkg_public_key, _ = dkg_public_key_data
return dkg_public_key
@pytest.fixture(scope="session")
def aggregated_transcript(dkg_public_key_data) -> AggregatedTranscript:
aggregated_transcript, _, _ = dkg_public_key_data
return aggregated_transcript

View File

@ -1,10 +1,11 @@
import time
from enum import Enum
from typing import Dict, List, Union
from typing import Dict, List
from eth_typing import ChecksumAddress
from eth_utils import keccak
from ferveo_py.ferveo_py import DkgPublicKey
from ferveo_py.ferveo_py import AggregatedTranscript, DkgPublicKey, Transcript
from nucypher_core.umbral import PublicKey
from web3.types import TxReceipt
from nucypher.blockchain.eth.agents import CoordinatorAgent
@ -80,10 +81,10 @@ class MockCoordinatorAgent(MockContractAgent):
return self.blockchain.FAKE_RECEIPT
def post_transcript(
self,
ritual_id: int,
transcript: bytes,
transacting_power: TransactingPower
self,
ritual_id: int,
transcript: Transcript,
transacting_power: TransactingPower,
) -> TxReceipt:
ritual = self.rituals[ritual_id]
operator_address = transacting_power.account
@ -93,7 +94,7 @@ class MockCoordinatorAgent(MockContractAgent):
or transacting_power.account
)
participant = self.get_participant_from_provider(ritual_id, provider)
participant.transcript = transcript
participant.transcript = bytes(transcript)
ritual.total_transcripts += 1
if ritual.total_transcripts == ritual.dkg_size:
ritual.status = self.RitualStatus.AWAITING_AGGREGATIONS
@ -109,9 +110,9 @@ class MockCoordinatorAgent(MockContractAgent):
def post_aggregation(
self,
ritual_id: int,
aggregated_transcript: bytes,
aggregated_transcript: AggregatedTranscript,
public_key: DkgPublicKey,
request_encrypting_key: bytes,
request_encrypting_key: PublicKey,
transacting_power: TransactingPower,
) -> TxReceipt:
ritual = self.rituals[ritual_id]
@ -123,15 +124,15 @@ class MockCoordinatorAgent(MockContractAgent):
)
participant = self.get_participant_from_provider(ritual_id, provider)
participant.aggregated = True
participant.requestEncryptingKey = request_encrypting_key
participant.requestEncryptingKey = request_encrypting_key.to_compressed_bytes()
g1_point = self.Ritual.G1Point.from_dkg_public_key(public_key)
if len(ritual.aggregated_transcript) == 0:
ritual.aggregated_transcript = aggregated_transcript
ritual.aggregated_transcript = bytes(aggregated_transcript)
ritual.public_key = g1_point
elif bytes(ritual.public_key) != bytes(g1_point) or keccak(
ritual.aggregated_transcript
) != keccak(aggregated_transcript):
) != keccak(bytes(aggregated_transcript)):
ritual.aggregation_mismatch = True
# don't increment aggregations
# TODO Emit EndRitual here?

View File

@ -1,10 +1,13 @@
import pytest
from ferveo_py.ferveo_py import Keypair as FerveoKeyPair
from ferveo_py.ferveo_py import Validator
from nucypher.blockchain.economics import EconomicsFactory
from nucypher.blockchain.eth.actors import Operator
from nucypher.blockchain.eth.agents import ContractAgency
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from nucypher.blockchain.eth.registry import InMemoryContractRegistry
from nucypher.crypto.ferveo import dkg
from nucypher.crypto.powers import TransactingPower
from nucypher.network.nodes import Teacher
from tests.mock.interfaces import MockBlockchain, MockEthereumClient
@ -88,3 +91,30 @@ def mock_substantiate_stamp(module_mocker, monkeymodule):
module_mocker.patch.object(Ursula, "_substantiate_stamp", autospec=True)
module_mocker.patch.object(Ursula, "operator_signature", fake_signature)
module_mocker.patch.object(Teacher, "validate_operator")
@pytest.fixture(scope="session")
def random_transcript(get_random_checksum_address):
ritual_id = 0
num_shares = 4
threshold = 3
validators = []
for i in range(0, num_shares):
validators.append(
Validator(
address=get_random_checksum_address(),
public_key=FerveoKeyPair.random().public_key(),
)
)
validators.sort(key=lambda x: x.address) # must be sorte
transcript = dkg.generate_transcript(
ritual_id=ritual_id,
me=validators[0],
shares=num_shares,
threshold=threshold,
nodes=validators,
)
return transcript

View File

@ -1,11 +1,10 @@
import os
from collections import OrderedDict
from unittest.mock import Mock
import pytest
from eth_account import Account
from nucypher_core.umbral import SecretKey
from tests.constants import FAKE_TRANSCRIPT
from tests.mock.coordinator import MockCoordinatorAgent
from tests.mock.interfaces import MockBlockchain
@ -59,7 +58,9 @@ def test_mock_coordinator_initiation(mocker, nodes_transacting_powers, coordinat
assert set(signal_data["participants"]) == nodes_transacting_powers.keys()
def test_mock_coordinator_round_1(nodes_transacting_powers, coordinator):
def test_mock_coordinator_round_1(
nodes_transacting_powers, coordinator, random_transcript
):
ritual = coordinator.rituals[0]
assert (
coordinator.get_ritual_status(0)
@ -70,7 +71,7 @@ def test_mock_coordinator_round_1(nodes_transacting_powers, coordinator):
assert p.transcript == bytes()
for index, node_address in enumerate(nodes_transacting_powers):
transcript = FAKE_TRANSCRIPT
transcript = random_transcript
coordinator.post_transcript(
ritual_id=0,
@ -79,7 +80,7 @@ def test_mock_coordinator_round_1(nodes_transacting_powers, coordinator):
)
performance = ritual.participants[index]
assert performance.transcript == transcript
assert performance.transcript == bytes(transcript)
if index == len(nodes_transacting_powers) - 1:
assert len(coordinator.EVENTS) == 2
@ -91,7 +92,11 @@ def test_mock_coordinator_round_1(nodes_transacting_powers, coordinator):
def test_mock_coordinator_round_2(
nodes_transacting_powers, coordinator, dkg_public_key
nodes_transacting_powers,
coordinator,
aggregated_transcript,
dkg_public_key,
random_transcript,
):
ritual = coordinator.rituals[0]
assert (
@ -100,12 +105,11 @@ def test_mock_coordinator_round_2(
)
for p in ritual.participants:
assert p.transcript == FAKE_TRANSCRIPT
assert p.transcript == bytes(random_transcript)
aggregated_transcript = os.urandom(len(FAKE_TRANSCRIPT))
request_encrypting_keys = []
for index, node_address in enumerate(nodes_transacting_powers):
request_encrypting_key = os.urandom(32)
request_encrypting_key = SecretKey.random().public_key()
coordinator.post_aggregation(
ritual_id=0,
aggregated_transcript=aggregated_transcript,
@ -117,14 +121,17 @@ def test_mock_coordinator_round_2(
if index == len(nodes_transacting_powers) - 1:
assert len(coordinator.EVENTS) == 2
assert ritual.aggregated_transcript == aggregated_transcript
assert ritual.aggregated_transcript == bytes(aggregated_transcript)
assert bytes(ritual.public_key) == bytes(dkg_public_key)
for index, p in enumerate(ritual.participants):
# unchanged
assert p.transcript == FAKE_TRANSCRIPT
assert p.transcript != aggregated_transcript
assert p.requestEncryptingKey == request_encrypting_keys[index]
assert p.transcript == bytes(random_transcript)
assert p.transcript != bytes(aggregated_transcript)
assert (
p.requestEncryptingKey
== request_encrypting_keys[index].to_compressed_bytes()
)
assert len(coordinator.EVENTS) == 2 # no additional event emitted here?
assert (

View File

@ -3,7 +3,6 @@ import pytest
from nucypher.blockchain.eth.agents import CoordinatorAgent
from nucypher.blockchain.eth.signers.software import Web3Signer
from nucypher.crypto.powers import TransactingPower
from tests.constants import FAKE_TRANSCRIPT
from tests.mock.coordinator import MockCoordinatorAgent
@ -84,10 +83,12 @@ def test_perform_round_1(ursula, random_address, cohort, agent):
)
def test_perform_round_2(ursula, cohort, transacting_power, agent, mocker):
def test_perform_round_2(
ursula, cohort, transacting_power, agent, mocker, random_transcript
):
participants = [
CoordinatorAgent.Ritual.Participant(
provider=c, aggregated=False, transcript=FAKE_TRANSCRIPT
provider=c, aggregated=False, transcript=bytes(random_transcript)
)
for c in cohort
]