- atx integration test
- eliminate 'fire and forget'
- warp space-time
pull/3475/head
KPrasch 2024-01-30 16:08:58 +01:00 committed by derekpierre
parent 8eef3b1f2c
commit 416164a846
No known key found for this signature in database
27 changed files with 1036 additions and 925 deletions

View File

@ -0,0 +1 @@
Introduces automated protocol transaction retries

View File

@ -6,8 +6,8 @@ from decimal import Decimal
from typing import DefaultDict, Dict, List, Optional, Set, Tuple, Union
import maya
from atxm.tx import AsyncTx
from eth_typing import ChecksumAddress
from hexbytes import HexBytes
from nucypher_core import (
EncryptedThresholdDecryptionRequest,
EncryptedThresholdDecryptionResponse,
@ -26,7 +26,6 @@ from nucypher_core.ferveo import (
Validator,
)
from web3 import HTTPProvider, Web3
from web3.exceptions import TransactionNotFound
from web3.types import TxReceipt
from nucypher.acumen.nicknames import Nickname
@ -57,6 +56,7 @@ from nucypher.datastore.dkg import DKGStorage
from nucypher.policy.conditions.evm import _CONDITION_CHAINS
from nucypher.policy.conditions.utils import evaluate_condition_lingo
from nucypher.policy.payment import ContractPayment
from nucypher.types import PhaseId, RitualId
from nucypher.utilities.emitters import StdoutEmitter
from nucypher.utilities.logging import Logger
@ -78,15 +78,17 @@ class BaseActor:
checksum_address: Optional[ChecksumAddress] = None,
):
if not (bool(checksum_address) ^ bool(transacting_power)):
error = f'Pass transacting power or checksum address, got {checksum_address} and {transacting_power}.'
error = f"Pass transacting power or checksum address, got {checksum_address} and {transacting_power}."
raise ValueError(error)
try:
parent_address = self.checksum_address
if checksum_address is not None:
if parent_address != checksum_address:
raise ValueError(f"Can't have two different ethereum addresses. "
f"Got {parent_address} and {checksum_address}.")
raise ValueError(
f"Can't have two different ethereum addresses. "
f"Got {parent_address} and {checksum_address}."
)
except AttributeError:
if transacting_power:
self.checksum_address = transacting_power.account
@ -114,9 +116,11 @@ class BaseActor:
@property
def eth_balance(self) -> Decimal:
"""Return this actor's current ETH balance"""
blockchain = BlockchainInterfaceFactory.get_interface() # TODO: EthAgent? #1509
blockchain = (
BlockchainInterfaceFactory.get_interface()
) # TODO: EthAgent? #1509
balance = blockchain.client.get_balance(self.wallet_address)
return Web3.from_wei(balance, 'ether')
return Web3.from_wei(balance, "ether")
@property
def wallet_address(self):
@ -223,8 +227,7 @@ class Operator(BaseActor):
self.publish_finalization = (
publish_finalization # publish the DKG final key if True
)
# TODO: #3052 stores locally generated public DKG artifacts
self.dkg_storage = DKGStorage()
self.ritual_power = crypto_power.power_ups(
RitualisticPower
) # ferveo material contained within
@ -236,6 +239,8 @@ class Operator(BaseActor):
condition_blockchain_endpoints
)
self.dkg_storage = DKGStorage()
def set_provider_public_key(self) -> Union[TxReceipt, None]:
# TODO: Here we're assuming there is one global key per node. See nucypher/#3167
node_global_ferveo_key_set = self.coordinator_agent.is_provider_public_key_set(
@ -319,7 +324,7 @@ class Operator(BaseActor):
# Local
external_validator = Validator(
address=self.checksum_address,
public_key=self.ritual_power.public_key()
public_key=self.ritual_power.public_key(),
)
else:
# Remote
@ -341,78 +346,40 @@ class Operator(BaseActor):
return result
def publish_transcript(self, ritual_id: int, transcript: Transcript) -> HexBytes:
tx_hash = self.coordinator_agent.post_transcript(
def publish_transcript(self, ritual_id: int, transcript: Transcript) -> AsyncTx:
async_tx = self.coordinator_agent.post_transcript(
ritual_id=ritual_id,
transcript=transcript,
transacting_power=self.transacting_power,
)
return tx_hash
identifier = self._phase_id(ritual_id, PHASE1)
self.ritual_tracker.active_rituals[identifier] = async_tx
return async_tx
def publish_aggregated_transcript(
self,
ritual_id: int,
aggregated_transcript: AggregatedTranscript,
public_key: DkgPublicKey,
) -> HexBytes:
) -> AsyncTx:
"""Publish an aggregated transcript to publicly available storage."""
# look up the node index for this node on the blockchain
participant_public_key = self.threshold_request_power.get_pubkey_from_ritual_id(
ritual_id
)
tx_hash = self.coordinator_agent.post_aggregation(
tx = self.coordinator_agent.post_aggregation(
ritual_id=ritual_id,
aggregated_transcript=aggregated_transcript,
public_key=public_key,
participant_public_key=participant_public_key,
transacting_power=self.transacting_power,
)
return tx_hash
def get_phase_receipt(
self, ritual_id: int, phase: int
) -> Tuple[Optional[HexBytes], Optional[TxReceipt]]:
if phase == 1:
txhash = self.dkg_storage.get_transcript_txhash(ritual_id=ritual_id)
elif phase == 2:
txhash = self.dkg_storage.get_aggregation_txhash(ritual_id=ritual_id)
else:
raise ValueError(f"Invalid phase: '{phase}'.")
if not txhash:
return None, None
try:
blockchain = self.coordinator_agent.blockchain.client
receipt = blockchain.get_transaction_receipt(txhash)
except TransactionNotFound:
return txhash, None
# at least for now (pre dkg tracker) - clear since receipt obtained
if phase == 1:
self.dkg_storage.clear_transcript_txhash(ritual_id, txhash)
else:
self.dkg_storage.clear_aggregated_txhash(ritual_id, txhash)
status = receipt.get("status")
if status == 1:
return txhash, receipt
else:
return None, None
def _phase_has_pending_tx(self, ritual_id: int, phase: int) -> bool:
tx_hash, _ = self.get_phase_receipt(ritual_id=ritual_id, phase=phase)
if not tx_hash:
return False
self.log.info(
f"Node {self.transacting_power.account} has pending tx {bytes(tx_hash).hex()} "
f"for ritual #{ritual_id}, phase #{phase}; skipping execution"
)
return True
identifier = (RitualId(ritual_id), PHASE2)
self.ritual_tracker.active_rituals[identifier] = tx
return tx
def _is_phase_1_action_required(self, ritual_id: int) -> bool:
"""Check whether node needs to perform a DKG round 1 action."""
# handle pending transactions
if self._phase_has_pending_tx(ritual_id=ritual_id, phase=PHASE1):
return False
# check ritual status from the blockchain
status = self.coordinator_agent.get_ritual_status(ritual_id=ritual_id)
@ -440,27 +407,74 @@ class Operator(BaseActor):
return True
@staticmethod
def _phase_id(ritual_id: int, phase: int) -> Tuple[RitualId, PhaseId]:
return RitualId(ritual_id), PhaseId(phase)
def perform_round_1(
self,
ritual_id: int,
authority: ChecksumAddress,
participants: List[ChecksumAddress],
timestamp: int,
) -> Optional[HexBytes]:
"""Perform round 1 of the DKG protocol for a given ritual ID on this node."""
) -> Optional[AsyncTx]:
"""
Perform phase 1 of the DKG protocol for a given ritual ID on this node.
This method is idempotent and will not submit a transcript if one has
already been submitted. It is dispatched by the EventActuator when it
receives a StartRitual event from the blockchain. Since the EventActuator
scans overlapping blocks, it is possible that this method will be called
multiple times for the same ritual. This method will check the state of
the ritual and participant on the blockchain before submitting a transcript.
If a there is a tracked AsyncTx for the given ritual and round
combination, this method will return the tracked transaction. If there is
no tracked transaction, this method will submit a transcript and return the
resulting FutureTx.
Returning None indicates that no action was required or taken.
Errors raised by this method are not explicitly caught and are expected
to be handled by the EventActuator.
"""
if self.checksum_address not in participants:
# should never get here
self.log.error(
f"Not part of ritual {ritual_id}; no need to submit transcripts"
)
raise RuntimeError(
f"Not participating in ritual {ritual_id}; should not have been notified"
# ERROR: This is an abnormal state since this method
# is designed to be invoked only when this node
# is an on-chain participant in the Coordinator.StartRitual event.
#
# This is a *nearly* a critical error. It's possible that the
# log it as an error and return None to avoid crashing upstack
# async tasks and drawing unnecessary amounts of attention to the issue.
message = (
f"{self.checksum_address}|{self.wallet_address} "
f"is not a member of ritual {ritual_id}"
)
self.log.error(message)
return
# check phase 1 contract state
if not self._is_phase_1_action_required(ritual_id=ritual_id):
self.log.debug(
"No action required for phase 1 of DKG protocol for some reason or another."
)
return
# check if there is already pending tx for this ritual + round combination
async_tx = self.ritual_tracker.active_rituals.get(
self._phase_id(ritual_id, PHASE1)
)
if async_tx:
self.log.info(
f"Active ritual in progress: {self.transacting_power.account} has submitted tx"
f"for ritual #{ritual_id}, phase #{PHASE1} (final: {async_tx.final})"
)
return async_tx
#
# Perform phase 1 of the DKG protocol
#
ritual = self.coordinator_agent.get_ritual(
ritual_id=ritual_id,
transcripts=False,
@ -490,9 +504,8 @@ class Operator(BaseActor):
raise e
# publish the transcript and store the receipt
tx_hash = self.publish_transcript(ritual_id=ritual.id, transcript=transcript)
self.dkg_storage.store_transcript_txhash(ritual_id=ritual.id, txhash=tx_hash)
self.dkg_storage.store_validators(ritual_id=ritual.id, validators=validators)
async_tx = self.publish_transcript(ritual_id=ritual.id, transcript=transcript)
# logging
arrival = ritual.total_transcripts + 1
@ -500,13 +513,10 @@ class Operator(BaseActor):
f"{self.transacting_power.account[:8]} submitted a transcript for "
f"DKG ritual #{ritual.id} ({arrival}/{ritual.dkg_size}) with authority {authority}."
)
return tx_hash
return async_tx
def _is_phase_2_action_required(self, ritual_id: int) -> bool:
"""Check whether node needs to perform a DKG round 2 action."""
# check if there is a pending tx for this ritual + round combination
if self._phase_has_pending_tx(ritual_id=ritual_id, phase=PHASE2):
return False
# check ritual status from the blockchain
status = self.coordinator_agent.get_ritual_status(ritual_id=ritual_id)
@ -535,12 +545,23 @@ class Operator(BaseActor):
return True
def perform_round_2(self, ritual_id: int, timestamp: int) -> Optional[HexBytes]:
def perform_round_2(self, ritual_id: int, timestamp: int) -> Optional[AsyncTx]:
"""Perform round 2 of the DKG protocol for the given ritual ID on this node."""
# check phase 2 state
if not self._is_phase_2_action_required(ritual_id=ritual_id):
return
# check if there is a pending tx for this ritual + round combination
async_tx = self.ritual_tracker.active_rituals.get(
self._phase_id(ritual_id, PHASE2)
)
if async_tx:
self.log.info(
f"Active ritual in progress Node {self.transacting_power.account} has submitted tx"
f"for ritual #{ritual_id}, phase #{PHASE1} (final: {async_tx.final})."
)
return async_tx
ritual = self.coordinator_agent.get_ritual(
ritual_id=ritual_id,
transcripts=True,
@ -556,14 +577,15 @@ class Operator(BaseActor):
transcripts = (Transcript.from_bytes(bytes(t)) for t in ritual.transcripts)
messages = list(zip(validators, transcripts))
try:
aggregated_transcript, dkg_public_key = (
self.ritual_power.aggregate_transcripts(
threshold=ritual.threshold,
shares=ritual.shares,
checksum_address=self.checksum_address,
ritual_id=ritual.id,
transcripts=messages,
)
(
aggregated_transcript,
dkg_public_key,
) = self.ritual_power.aggregate_transcripts(
threshold=ritual.threshold,
shares=ritual.shares,
checksum_address=self.checksum_address,
ritual_id=ritual.id,
transcripts=messages,
)
except Exception as e:
self.log.debug(
@ -573,15 +595,12 @@ class Operator(BaseActor):
# publish the transcript with network-wide jitter to avoid tx congestion
time.sleep(random.randint(0, self.AGGREGATION_SUBMISSION_MAX_DELAY))
tx_hash = self.publish_aggregated_transcript(
async_tx = self.publish_aggregated_transcript(
ritual_id=ritual.id,
aggregated_transcript=aggregated_transcript,
public_key=dkg_public_key,
)
# store the txhash
self.dkg_storage.store_aggregation_txhash(ritual_id=ritual.id, txhash=tx_hash)
# logging
total = ritual.total_aggregations + 1
self.log.debug(
@ -590,7 +609,8 @@ class Operator(BaseActor):
)
if total >= ritual.dkg_size:
self.log.debug(f"DKG ritual #{ritual.id} should now be finalized")
return tx_hash
return async_tx
def derive_decryption_share(
self,
@ -600,13 +620,10 @@ class Operator(BaseActor):
variant: FerveoVariant,
) -> Union[DecryptionShareSimple, DecryptionSharePrecomputed]:
ritual = self._resolve_ritual(ritual_id)
validators = self._resolve_validators(ritual)
aggregated_transcript = AggregatedTranscript.from_bytes(
bytes(ritual.aggregated_transcript)
)
decryption_share = self.ritual_power.derive_decryption_share(
nodes=validators,
threshold=ritual.threshold,
@ -616,7 +633,7 @@ class Operator(BaseActor):
aggregated_transcript=aggregated_transcript,
ciphertext_header=ciphertext_header,
aad=aad,
variant=variant
variant=variant,
)
return decryption_share

View File

@ -13,10 +13,10 @@ from typing import (
Optional,
Tuple,
Type,
Union,
cast,
)
from atxm.tx import AsyncTx
from constant_sorrow.constants import (
# type: ignore
CONTRACT_CALL,
@ -24,7 +24,6 @@ from constant_sorrow.constants import (
)
from eth_typing.evm import ChecksumAddress
from eth_utils import to_checksum_address, to_int
from hexbytes import HexBytes
from nucypher_core import SessionStaticKey
from nucypher_core.ferveo import (
AggregatedTranscript,
@ -46,7 +45,7 @@ from nucypher.blockchain.eth.constants import (
)
from nucypher.blockchain.eth.decorators import contract_api
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from nucypher.blockchain.eth.models import Coordinator, Ferveo
from nucypher.blockchain.eth.models import PHASE1, PHASE2, Coordinator, Ferveo
from nucypher.blockchain.eth.registry import (
ContractRegistry,
)
@ -68,7 +67,7 @@ class EthereumContractAgent:
# TODO - #842: Gas Management
DEFAULT_TRANSACTION_GAS_LIMITS: Dict[str, Optional[Wei]]
DEFAULT_TRANSACTION_GAS_LIMITS = {'default': None}
DEFAULT_TRANSACTION_GAS_LIMITS = {"default": None}
class ContractNotDeployed(Exception):
"""Raised when attempting to access a contract that is not deployed on the current network."""
@ -86,7 +85,6 @@ class EthereumContractAgent:
contract: Optional[Contract] = None,
transaction_gas: Optional[Wei] = None,
):
self.log = Logger(self.__class__.__name__)
self.registry = registry
@ -103,7 +101,9 @@ class EthereumContractAgent:
self.__contract = contract
self.events = events.ContractEvents(contract)
if not transaction_gas:
transaction_gas = EthereumContractAgent.DEFAULT_TRANSACTION_GAS_LIMITS['default']
transaction_gas = EthereumContractAgent.DEFAULT_TRANSACTION_GAS_LIMITS[
"default"
]
self.transaction_gas = transaction_gas
self.log.info(
@ -133,7 +133,6 @@ class EthereumContractAgent:
class NucypherTokenAgent(EthereumContractAgent):
contract_name: str = NUCYPHER_TOKEN_CONTRACT_NAME
@contract_api(CONTRACT_CALL)
@ -158,9 +157,12 @@ class NucypherTokenAgent(EthereumContractAgent):
increase: types.NuNits,
) -> TxReceipt:
"""Increase the allowance of a spender address funded by a sender address"""
contract_function: ContractFunction = self.contract.functions.increaseAllowance(spender_address, increase)
receipt: TxReceipt = self.blockchain.send_transaction(contract_function=contract_function,
transacting_power=transacting_power)
contract_function: ContractFunction = self.contract.functions.increaseAllowance(
spender_address, increase
)
receipt: TxReceipt = self.blockchain.send_transaction(
contract_function=contract_function, transacting_power=transacting_power
)
return receipt
@contract_api(TRANSACTION)
@ -171,9 +173,12 @@ class NucypherTokenAgent(EthereumContractAgent):
decrease: types.NuNits,
) -> TxReceipt:
"""Decrease the allowance of a spender address funded by a sender address"""
contract_function: ContractFunction = self.contract.functions.decreaseAllowance(spender_address, decrease)
receipt: TxReceipt = self.blockchain.send_transaction(contract_function=contract_function,
transacting_power=transacting_power)
contract_function: ContractFunction = self.contract.functions.decreaseAllowance(
spender_address, decrease
)
receipt: TxReceipt = self.blockchain.send_transaction(
contract_function=contract_function, transacting_power=transacting_power
)
return receipt
@contract_api(TRANSACTION)
@ -186,11 +191,17 @@ class NucypherTokenAgent(EthereumContractAgent):
"""Approve the spender address to transfer an amount of tokens on behalf of the sender address"""
self._validate_zero_allowance(amount, spender_address, transacting_power)
payload: TxParams = {'gas': Wei(500_000)} # TODO #842: gas needed for use with geth! <<<< Is this still open?
contract_function: ContractFunction = self.contract.functions.approve(spender_address, amount)
receipt: TxReceipt = self.blockchain.send_transaction(contract_function=contract_function,
payload=payload,
transacting_power=transacting_power)
payload: TxParams = {
"gas": Wei(500_000)
} # TODO #842: gas needed for use with geth! <<<< Is this still open?
contract_function: ContractFunction = self.contract.functions.approve(
spender_address, amount
)
receipt: TxReceipt = self.blockchain.send_transaction(
contract_function=contract_function,
payload=payload,
transacting_power=transacting_power,
)
return receipt
@contract_api(TRANSACTION)
@ -201,9 +212,12 @@ class NucypherTokenAgent(EthereumContractAgent):
transacting_power: TransactingPower,
) -> TxReceipt:
"""Transfer an amount of tokens from the sender address to the target address."""
contract_function: ContractFunction = self.contract.functions.transfer(target_address, amount)
receipt: TxReceipt = self.blockchain.send_transaction(contract_function=contract_function,
transacting_power=transacting_power)
contract_function: ContractFunction = self.contract.functions.transfer(
target_address, amount
)
receipt: TxReceipt = self.blockchain.send_transaction(
contract_function=contract_function, transacting_power=transacting_power
)
return receipt
@contract_api(TRANSACTION)
@ -219,23 +233,30 @@ class NucypherTokenAgent(EthereumContractAgent):
payload = None
if gas_limit: # TODO: Gas management - #842
payload = {'gas': gas_limit}
approve_and_call: ContractFunction = self.contract.functions.approveAndCall(target_address, amount, call_data)
approve_and_call_receipt: TxReceipt = self.blockchain.send_transaction(contract_function=approve_and_call,
transacting_power=transacting_power,
payload=payload)
payload = {"gas": gas_limit}
approve_and_call: ContractFunction = self.contract.functions.approveAndCall(
target_address, amount, call_data
)
approve_and_call_receipt: TxReceipt = self.blockchain.send_transaction(
contract_function=approve_and_call,
transacting_power=transacting_power,
payload=payload,
)
return approve_and_call_receipt
def _validate_zero_allowance(self, amount, target_address, transacting_power):
if amount == 0:
return
current_allowance = self.get_allowance(owner=transacting_power.account, spender=target_address)
current_allowance = self.get_allowance(
owner=transacting_power.account, spender=target_address
)
if current_allowance != 0:
raise self.RequirementError(f"Token allowance for spender {target_address} must be 0")
raise self.RequirementError(
f"Token allowance for spender {target_address} must be 0"
)
class SubscriptionManagerAgent(EthereumContractAgent):
contract_name: str = SUBSCRIPTION_MANAGER_CONTRACT_NAME
class PolicyInfo(NamedTuple):
@ -267,7 +288,7 @@ class SubscriptionManagerAgent(EthereumContractAgent):
end_timestamp=record[2],
size=record[3],
# If the policyOwner addr is null, we return the sponsor addr instead of the owner.
owner=record[0] if record[4] == NULL_ADDRESS else record[4]
owner=record[0] if record[4] == NULL_ADDRESS else record[4],
)
return policy_info
@ -276,27 +297,25 @@ class SubscriptionManagerAgent(EthereumContractAgent):
#
@contract_api(TRANSACTION)
def create_policy(self,
policy_id: bytes,
transacting_power: TransactingPower,
size: int,
start_timestamp: Timestamp,
end_timestamp: Timestamp,
value: Wei,
owner_address: Optional[ChecksumAddress] = None) -> TxReceipt:
def create_policy(
self,
policy_id: bytes,
transacting_power: TransactingPower,
size: int,
start_timestamp: Timestamp,
end_timestamp: Timestamp,
value: Wei,
owner_address: Optional[ChecksumAddress] = None,
) -> TxReceipt:
owner_address = owner_address or transacting_power.account
payload: TxParams = {'value': value}
payload: TxParams = {"value": value}
contract_function: ContractFunction = self.contract.functions.createPolicy(
policy_id,
owner_address,
size,
start_timestamp,
end_timestamp
policy_id, owner_address, size, start_timestamp, end_timestamp
)
receipt = self.blockchain.send_transaction(
contract_function=contract_function,
payload=payload,
transacting_power=transacting_power
transacting_power=transacting_power,
)
return receipt
@ -566,7 +585,9 @@ class TACoApplicationAgent(StakerSamplingApplicationAgent):
self, staking_provider: ChecksumAddress
) -> StakingProviderInfo:
# remove reserved fields
info: list = self.contract.functions.stakingProviderInfo(staking_provider).call()
info: list = self.contract.functions.stakingProviderInfo(
staking_provider
).call()
return TACoApplicationAgent.StakingProviderInfo(*info[0:3])
@contract_api(CONTRACT_CALL)
@ -610,11 +631,19 @@ class TACoApplicationAgent(StakerSamplingApplicationAgent):
#
@contract_api(TRANSACTION)
def bond_operator(self, staking_provider: ChecksumAddress, operator: ChecksumAddress, transacting_power: TransactingPower) -> TxReceipt:
def bond_operator(
self,
staking_provider: ChecksumAddress,
operator: ChecksumAddress,
transacting_power: TransactingPower,
) -> TxReceipt:
"""For use by threshold operator accounts only."""
contract_function: ContractFunction = self.contract.functions.bondOperator(staking_provider, operator)
receipt = self.blockchain.send_transaction(contract_function=contract_function,
transacting_power=transacting_power)
contract_function: ContractFunction = self.contract.functions.bondOperator(
staking_provider, operator
)
receipt = self.blockchain.send_transaction(
contract_function=contract_function, transacting_power=transacting_power
)
return receipt
@ -796,17 +825,16 @@ class CoordinatorAgent(EthereumContractAgent):
ritual_id: int,
transcript: Transcript,
transacting_power: TransactingPower,
fire_and_forget: bool = True,
) -> Union[TxReceipt, HexBytes]:
) -> AsyncTx:
contract_function: ContractFunction = self.contract.functions.postTranscript(
ritualId=ritual_id, transcript=bytes(transcript)
)
receipt = self.blockchain.send_transaction(
atx = self.blockchain.send_async_transaction(
contract_function=contract_function,
transacting_power=transacting_power,
fire_and_forget=fire_and_forget,
info={"ritual_id": ritual_id, "phase": PHASE1},
)
return receipt
return atx
@contract_api(TRANSACTION)
def post_aggregation(
@ -816,23 +844,22 @@ class CoordinatorAgent(EthereumContractAgent):
public_key: DkgPublicKey,
participant_public_key: SessionStaticKey,
transacting_power: TransactingPower,
fire_and_forget: bool = True,
) -> Union[TxReceipt, HexBytes]:
) -> AsyncTx:
contract_function: ContractFunction = self.contract.functions.postAggregation(
ritualId=ritual_id,
aggregatedTranscript=bytes(aggregated_transcript),
dkgPublicKey=Ferveo.G1Point.from_dkg_public_key(public_key),
decryptionRequestStaticKey=bytes(participant_public_key),
)
receipt = self.blockchain.send_transaction(
atx = self.blockchain.send_async_transaction(
contract_function=contract_function,
gas_estimation_multiplier=1.4,
transacting_power=transacting_power,
fire_and_forget=fire_and_forget,
info={"ritual_id": ritual_id, "phase": PHASE2},
)
return receipt
return atx
@contract_api(TRANSACTION)
@contract_api(CONTRACT_CALL)
def get_ritual_initiation_cost(
self, providers: List[ChecksumAddress], duration: int
) -> Wei:
@ -841,7 +868,7 @@ class CoordinatorAgent(EthereumContractAgent):
).call()
return Wei(result)
@contract_api(TRANSACTION)
@contract_api(CONTRACT_CALL)
def get_ritual_id_from_public_key(self, public_key: DkgPublicKey) -> int:
g1_point = Ferveo.G1Point.from_dkg_public_key(public_key)
result = self.contract.functions.getRitualIdFromPublicKey(g1_point).call()
@ -865,7 +892,9 @@ class ContractAgency:
"""Where agents live and die."""
# TODO: Enforce singleton - #1506 - Okay, actually, make this into a module
__agents: Dict[str, Dict[Type[EthereumContractAgent], EthereumContractAgent]] = dict()
__agents: Dict[str, Dict[Type[EthereumContractAgent], EthereumContractAgent]] = (
dict()
)
@classmethod
def get_agent(
@ -926,7 +955,7 @@ class ContractAgency:
agent_class=agent_class,
registry=registry,
blockchain_endpoint=blockchain_endpoint,
contract_version=contract_version
contract_version=contract_version,
)
return agent
@ -959,7 +988,9 @@ class WeightedSampler:
return []
if quantity > len(self):
raise ValueError("Cannot sample more than the total amount of elements without replacement")
raise ValueError(
"Cannot sample more than the total amount of elements without replacement"
)
samples = []
@ -984,7 +1015,6 @@ class WeightedSampler:
class StakingProvidersReservoir:
def __init__(self, staking_provider_map: Dict[ChecksumAddress, int]):
self._sampler = WeightedSampler(staking_provider_map)
self._rng = random.SystemRandom()

View File

@ -6,6 +6,7 @@ from typing import Union
from constant_sorrow.constants import UNKNOWN_DEVELOPMENT_CHAIN_ID
from cytoolz.dicttoolz import dissoc
from eth_account import Account
from eth_account.datastructures import SignedTransaction
from eth_account.messages import encode_defunct
from eth_typing.evm import BlockNumber, ChecksumAddress
from eth_utils import to_canonical_address, to_checksum_address
@ -555,13 +556,12 @@ class EthereumTesterClient(EthereumClient):
raise self.UnknownAccount(account)
return signing_key
def sign_transaction(self, transaction_dict: dict) -> bytes:
def sign_transaction(self, transaction_dict: dict) -> SignedTransaction:
# Sign using a local private key
address = to_canonical_address(transaction_dict['from'])
signing_key = self.__get_signing_key(account=address)
signed_transaction = self.w3.eth.account.sign_transaction(transaction_dict, private_key=signing_key)
rlp_transaction = signed_transaction.rawTransaction
return rlp_transaction
return signed_transaction
def sign_message(self, account: str, message: bytes) -> str:
"""Sign, EIP-191 (Geth) Style"""

View File

@ -5,6 +5,7 @@ from typing import Callable, Dict, NamedTuple, Optional, Union
from urllib.parse import urlparse
import requests
from atxm import AutomaticTxMachine
from constant_sorrow.constants import (
INSUFFICIENT_FUNDS,
NO_BLOCKCHAIN_CONNECTION,
@ -24,7 +25,7 @@ from web3.contract.contract import Contract, ContractConstructor, ContractFuncti
from web3.exceptions import TimeExhausted
from web3.middleware import geth_poa_middleware, simple_cache_middleware
from web3.providers import BaseProvider
from web3.types import TxReceipt
from web3.types import TxParams, TxReceipt
from nucypher.blockchain.eth.clients import POA_CHAINS, EthereumClient, InfuraClient
from nucypher.blockchain.eth.decorators import validate_checksum_address
@ -47,7 +48,9 @@ from nucypher.utilities.gas_strategies import (
)
from nucypher.utilities.logging import Logger
Web3Providers = Union[IPCProvider, WebsocketProvider, HTTPProvider, EthereumTester] # TODO: Move to types.py
Web3Providers = Union[
IPCProvider, WebsocketProvider, HTTPProvider, EthereumTester
] # TODO: Move to types.py
class BlockchainInterface:
@ -58,7 +61,7 @@ class BlockchainInterface:
TIMEOUT = 600 # seconds # TODO: Correlate with the gas strategy - #2070
DEFAULT_GAS_STRATEGY = 'fast'
DEFAULT_GAS_STRATEGY = "fast"
GAS_STRATEGIES = WEB3_GAS_STRATEGIES
Web3 = Web3 # TODO: This is name-shadowing the actual Web3. Is this intentional?
@ -85,15 +88,15 @@ class BlockchainInterface:
}
class TransactionFailed(InterfaceError):
IPC_CODE = -32000
def __init__(self,
message: str,
transaction_dict: dict,
contract_function: Union[ContractFunction, ContractConstructor],
*args):
def __init__(
self,
message: str,
transaction_dict: dict,
contract_function: Union[ContractFunction, ContractConstructor],
*args,
):
self.base_message = message
self.name = get_transaction_name(contract_function=contract_function)
self.payload = transaction_dict
@ -107,28 +110,32 @@ class BlockchainInterface:
@property
def default(self) -> str:
sender = self.payload["from"]
message = f'{self.name} from {sender[:6]}... \n' \
f'Sender balance: {prettify_eth_amount(self.get_balance())} \n' \
f'Reason: {self.base_message} \n' \
f'Transaction: {self.payload}'
message = (
f"{self.name} from {sender[:6]}... \n"
f"Sender balance: {prettify_eth_amount(self.get_balance())} \n"
f"Reason: {self.base_message} \n"
f"Transaction: {self.payload}"
)
return message
def get_balance(self):
blockchain = BlockchainInterfaceFactory.get_interface()
balance = blockchain.client.get_balance(account=self.payload['from'])
balance = blockchain.client.get_balance(account=self.payload["from"])
return balance
@property
def insufficient_funds(self) -> str:
try:
transaction_fee = self.payload['gas'] * self.payload['gasPrice']
transaction_fee = self.payload["gas"] * self.payload["gasPrice"]
except KeyError:
return self.default
else:
cost = transaction_fee + self.payload.get('value', 0)
message = f'{self.name} from {self.payload["from"][:8]} - {self.base_message}.' \
f'Calculated cost is {prettify_eth_amount(cost)},' \
f'but sender only has {prettify_eth_amount(self.get_balance())}.'
cost = transaction_fee + self.payload.get("value", 0)
message = (
f'{self.name} from {self.payload["from"][:8]} - {self.base_message}.'
f"Calculated cost is {prettify_eth_amount(cost)},"
f"but sender only has {prettify_eth_amount(self.get_balance())}."
)
return message
def __init__(
@ -204,16 +211,23 @@ class BlockchainInterface:
"""
self.log = Logger('Blockchain')
self.log = Logger("Blockchain")
self.poa = poa
self.endpoint = endpoint
self._provider = provider
self.w3 = NO_BLOCKCHAIN_CONNECTION
self.client: EthereumClient = NO_BLOCKCHAIN_CONNECTION
self.is_light = light
self.tx_machine = AutomaticTxMachine(w3=self.w3)
# TODO: Not ready to give users total flexibility. Let's stick for the moment to known values. See #2447
if gas_strategy not in ('slow', 'medium', 'fast', 'free', None): # FIXME: What is 'None' doing here?
if gas_strategy not in (
"slow",
"medium",
"fast",
"free",
None,
): # FIXME: What is 'None' doing here?
raise ValueError(f"'{gas_strategy}' is an invalid gas strategy")
self.gas_strategy = gas_strategy or self.DEFAULT_GAS_STRATEGY
self.max_gas_price = max_gas_price
@ -241,7 +255,9 @@ class BlockchainInterface:
except KeyError:
if gas_strategy:
if not callable(gas_strategy):
raise ValueError(f"{gas_strategy} must be callable to be a valid gas strategy.")
raise ValueError(
f"{gas_strategy} must be callable to be a valid gas strategy."
)
else:
gas_strategy = cls.GAS_STRATEGIES[cls.DEFAULT_GAS_STRATEGY]
return gas_strategy
@ -256,7 +272,7 @@ class BlockchainInterface:
# For use with Proof-Of-Authority test-blockchains
if self.poa is True:
self.log.debug('Injecting POA middleware at layer 0')
self.log.debug("Injecting POA middleware at layer 0")
self.client.inject_middleware(geth_poa_middleware, layer=0)
self.log.debug("Adding simple_cache_middleware")
@ -266,7 +282,6 @@ class BlockchainInterface:
# self.configure_gas_strategy()
def configure_gas_strategy(self, gas_strategy: Optional[Callable] = None) -> None:
if gas_strategy:
reported_gas_strategy = f"fixed/{gas_strategy.name}"
@ -281,8 +296,10 @@ class BlockchainInterface:
configuration_message = f"Using gas strategy '{reported_gas_strategy}'"
if self.max_gas_price:
__price = Web3.to_wei(self.max_gas_price, 'gwei') # from gwei to wei
gas_strategy = max_price_gas_strategy_wrapper(gas_strategy=gas_strategy, max_gas_price_wei=__price)
__price = Web3.to_wei(self.max_gas_price, "gwei") # from gwei to wei
gas_strategy = max_price_gas_strategy_wrapper(
gas_strategy=gas_strategy, max_gas_price_wei=__price
)
configuration_message += f", with a max price of {self.max_gas_price} gwei."
self.client.set_gas_strategy(gas_strategy=gas_strategy)
@ -295,7 +312,6 @@ class BlockchainInterface:
# self.log.debug(f"Gas strategy currently reports a gas price of {gwei_gas_price} gwei.")
def connect(self):
endpoint = self.endpoint
self.log.info(f"Using external Web3 Provider '{self.endpoint}'")
@ -311,6 +327,7 @@ class BlockchainInterface:
# Connect if not connected
try:
self.w3 = self.Web3(provider=self._provider)
self.tx_machine.w3 = self.w3 # share this web3 instance with the tracker
self.client = EthereumClient.from_w3(w3=self.w3)
except requests.ConnectionError: # RPC
raise self.ConnectionFailed(
@ -344,22 +361,22 @@ class BlockchainInterface:
if endpoint and not provider:
uri_breakdown = urlparse(endpoint)
if uri_breakdown.scheme == 'tester':
if uri_breakdown.scheme == "tester":
providers = {
'pyevm': _get_pyevm_test_provider,
'mock': _get_mock_test_provider
"pyevm": _get_pyevm_test_provider,
"mock": _get_mock_test_provider,
}
provider_scheme = uri_breakdown.netloc
else:
providers = {
'auto': _get_auto_provider,
'ipc': _get_IPC_provider,
'file': _get_IPC_provider,
'ws': _get_websocket_provider,
'wss': _get_websocket_provider,
'http': _get_HTTP_provider,
'https': _get_HTTP_provider,
"auto": _get_auto_provider,
"ipc": _get_IPC_provider,
"file": _get_IPC_provider,
"ws": _get_websocket_provider,
"wss": _get_websocket_provider,
"http": _get_HTTP_provider,
"https": _get_HTTP_provider,
}
provider_scheme = uri_breakdown.scheme
@ -384,12 +401,13 @@ class BlockchainInterface:
self._provider = provider
@classmethod
def _handle_failed_transaction(cls,
exception: Exception,
transaction_dict: dict,
contract_function: Union[ContractFunction, ContractConstructor],
logger: Logger = None
) -> None:
def _handle_failed_transaction(
cls,
exception: Exception,
transaction_dict: dict,
contract_function: Union[ContractFunction, ContractConstructor],
logger: Logger = None,
) -> None:
"""
Re-raising error handler and context manager for transaction broadcast or
build failure events at the interface layer. This method is a last line of defense
@ -401,8 +419,8 @@ class BlockchainInterface:
# Assume this error is formatted as an RPC response
try:
code = int(response['code'])
message = response['message']
code = int(response["code"])
message = response["message"]
except Exception:
# TODO: #1504 - Try even harder to determine if this is insufficient funds causing the issue,
# This may be best handled at the agent or actor layer for registry and token interactions.
@ -417,12 +435,16 @@ class BlockchainInterface:
if logger:
logger.critical(message) # simple context
transaction_failed = cls.TransactionFailed(message=message, # rich error (best case)
contract_function=contract_function,
transaction_dict=transaction_dict)
transaction_failed = cls.TransactionFailed(
message=message, # rich error (best case)
contract_function=contract_function,
transaction_dict=transaction_dict,
)
raise transaction_failed from exception
def __log_transaction(self, transaction_dict: dict, contract_function: ContractFunction):
def __log_transaction(
self, transaction_dict: dict, contract_function: ContractFunction
):
"""
Format and log a transaction dict and return the transaction name string.
This method *must not* mutate the original transaction dict.
@ -431,30 +453,38 @@ class BlockchainInterface:
tx = dict(transaction_dict).copy()
# Format
if tx.get('to'):
tx['to'] = to_checksum_address(contract_function.address)
if tx.get("to"):
tx["to"] = to_checksum_address(contract_function.address)
try:
tx['selector'] = contract_function.selector
tx["selector"] = contract_function.selector
except AttributeError:
pass
tx['from'] = to_checksum_address(tx['from'])
tx.update({f: prettify_eth_amount(v) for f, v in tx.items() if f in ('gasPrice', 'value')})
payload_pprint = ', '.join("{}: {}".format(k, v) for k, v in tx.items())
tx["from"] = to_checksum_address(tx["from"])
tx.update(
{
f: prettify_eth_amount(v)
for f, v in tx.items()
if f in ("gasPrice", "value")
}
)
payload_pprint = ", ".join("{}: {}".format(k, v) for k, v in tx.items())
# Log
transaction_name = get_transaction_name(contract_function=contract_function)
self.log.debug(f"[TX-{transaction_name}] | {payload_pprint}")
@validate_checksum_address
def build_payload(self,
sender_address: str,
payload: dict = None,
transaction_gas_limit: int = None,
use_pending_nonce: bool = True,
) -> dict:
nonce = self.client.get_transaction_count(account=sender_address, pending=use_pending_nonce)
base_payload = {'nonce': nonce, 'from': sender_address}
def build_payload(
self,
sender_address: str,
payload: dict = None,
transaction_gas_limit: int = None,
use_pending_nonce: bool = True,
) -> dict:
nonce = self.client.get_transaction_count(
account=sender_address, pending=use_pending_nonce
)
base_payload = {"nonce": nonce, "from": sender_address}
# Aggregate
if not payload:
@ -462,55 +492,72 @@ class BlockchainInterface:
payload.update(base_payload)
# Explicit gas override - will skip gas estimation in next operation.
if transaction_gas_limit:
payload['gas'] = int(transaction_gas_limit)
payload["gas"] = int(transaction_gas_limit)
return payload
@validate_checksum_address
def build_contract_transaction(self,
contract_function: ContractFunction,
sender_address: str,
payload: dict = None,
transaction_gas_limit: Optional[int] = None,
gas_estimation_multiplier: Optional[float] = None,
use_pending_nonce: Optional[bool] = None,
) -> dict:
def build_contract_transaction(
self,
contract_function: ContractFunction,
sender_address: str,
payload: dict = None,
transaction_gas_limit: Optional[int] = None,
gas_estimation_multiplier: Optional[float] = None,
use_pending_nonce: Optional[bool] = None,
log_now: bool = True,
) -> TxParams:
if transaction_gas_limit is not None:
self.log.warn("The transaction gas limit of {transaction_gas_limit} will override gas estimation attempts")
self.log.warn(
f"The transaction gas limit of {transaction_gas_limit} will override gas estimation attempts"
)
# Sanity checks for the gas estimation multiplier
if gas_estimation_multiplier is not None:
if not 1 <= gas_estimation_multiplier <= 3: # Arbitrary upper bound.
raise ValueError(f"The gas estimation multiplier should be a float between 1 and 3, "
f"but we received {gas_estimation_multiplier}.")
raise ValueError(
f"The gas estimation multiplier must be a float between 1 and 3, "
f"but we received {gas_estimation_multiplier}."
)
payload = self.build_payload(sender_address=sender_address,
payload=payload,
transaction_gas_limit=transaction_gas_limit,
use_pending_nonce=use_pending_nonce)
self.__log_transaction(transaction_dict=payload, contract_function=contract_function)
payload = self.build_payload(
sender_address=sender_address,
payload=payload,
transaction_gas_limit=transaction_gas_limit,
use_pending_nonce=use_pending_nonce,
)
if log_now:
self.__log_transaction(
transaction_dict=payload, contract_function=contract_function
)
try:
if 'gas' not in payload: # i.e., transaction_gas_limit is not None
if "gas" not in payload: # i.e., transaction_gas_limit is not None
# As web3 build_transaction() will estimate gas with block identifier "pending" by default,
# explicitly estimate gas here with block identifier 'latest' if not otherwise specified
# as a pending transaction can cause gas estimation to fail, notably in case of worklock refunds.
payload['gas'] = contract_function.estimate_gas(payload, block_identifier='latest')
payload["gas"] = contract_function.estimate_gas(
payload, block_identifier="latest"
)
transaction_dict = contract_function.build_transaction(payload)
except (TestTransactionFailed, ValidationError, ValueError) as error:
# Note: Geth (1.9.15) raises ValueError in the same condition that pyevm raises ValidationError here.
# Treat this condition as "Transaction Failed" during gas estimation.
raise self._handle_failed_transaction(exception=error,
transaction_dict=payload,
contract_function=contract_function,
logger=self.log)
raise self._handle_failed_transaction(
exception=error,
transaction_dict=payload,
contract_function=contract_function,
logger=self.log,
)
# Increase the estimated gas limit according to the gas estimation multiplier, if any.
if gas_estimation_multiplier and not transaction_gas_limit:
gas_estimation = transaction_dict['gas']
gas_estimation = transaction_dict["gas"]
overestimation = int(math.ceil(gas_estimation * gas_estimation_multiplier))
self.log.debug(f"Gas limit for this TX was increased from {gas_estimation} to {overestimation}, "
f"using a multiplier of {gas_estimation_multiplier}.")
transaction_dict['gas'] = overestimation
self.log.debug(
f"Gas limit for this TX was increased from {gas_estimation} to {overestimation}, "
f"using a multiplier of {gas_estimation_multiplier}."
)
transaction_dict["gas"] = overestimation
# TODO: What if we're going over the block limit? Not likely, but perhaps worth checking (NRN)
return transaction_dict
@ -521,141 +568,167 @@ class BlockchainInterface:
transaction_dict: Dict,
transaction_name: str = "",
confirmations: int = 0,
fire_and_forget: bool = False,
) -> Union[TxReceipt, HexBytes]:
"""
Takes a transaction dictionary, signs it with the configured signer, then broadcasts the signed
transaction using the ethereum provider's eth_sendRawTransaction RPC endpoint.
Optionally blocks for receipt and confirmation with 'confirmations', and 'fire_and_forget' flags.
If 'fire and forget' is True this method returns the transaction hash only, without waiting for a receipt -
otherwise return the transaction receipt.
Takes a transaction dictionary, signs it with the configured signer,
then broadcasts the signed transaction using the RPC provider's
eth_sendRawTransaction endpoint.
"""
#
# Setup
#
# TODO # 1754 - Move this to singleton - I do not approve... nor does Bogdan?
emitter = StdoutEmitter()
#
# Sign
#
# TODO: Show the USD Price: https://api.coinmarketcap.com/v1/ticker/ethereum/
try:
# post-london fork transactions (Type 2)
max_unit_price = transaction_dict['maxFeePerGas']
tx_type = 'EIP-1559'
max_unit_price = transaction_dict["maxFeePerGas"]
tx_type = "EIP-1559"
except KeyError:
# pre-london fork "legacy" transactions (Type 0)
max_unit_price = transaction_dict['gasPrice']
tx_type = 'Legacy'
max_unit_price = transaction_dict["gasPrice"]
tx_type = "Legacy"
max_price_gwei = Web3.from_wei(max_unit_price, 'gwei')
max_cost_wei = max_unit_price * transaction_dict['gas']
max_cost = Web3.from_wei(max_cost_wei, 'ether')
max_price_gwei = Web3.from_wei(max_unit_price, "gwei")
max_cost_wei = max_unit_price * transaction_dict["gas"]
max_cost = Web3.from_wei(max_cost_wei, "ether")
if transacting_power.is_device:
emitter.message(f'Confirm transaction {transaction_name} on hardware wallet... '
f'({max_cost} ETH @ {max_price_gwei} gwei)',
color='yellow')
signed_raw_transaction = transacting_power.sign_transaction(transaction_dict)
emitter.message(
f"Confirm transaction {transaction_name} on hardware wallet... "
f"({max_cost} ETH @ {max_price_gwei} gwei)",
color="yellow",
)
signed_transaction = transacting_power.sign_transaction(transaction_dict)
#
# Broadcast
#
emitter.message(f'Broadcasting {transaction_name} {tx_type} Transaction ({max_cost} ETH @ {max_price_gwei} gwei)',
color='yellow')
emitter.message(
f"Broadcasting {transaction_name} {tx_type} Transaction ({max_cost} ETH @ {max_price_gwei} gwei)",
color="yellow",
)
try:
txhash = self.client.send_raw_transaction(signed_raw_transaction) # <--- BROADCAST
emitter.message(f'TXHASH {txhash.hex()}', color='yellow')
txhash = self.client.send_raw_transaction(
signed_transaction.rawTransaction
) # <--- BROADCAST
emitter.message(f"TXHASH {txhash.hex()}", color="yellow")
except (TestTransactionFailed, ValueError):
raise # TODO: Unify with Transaction failed handling -- Entry point for _handle_failed_transaction
else:
if fire_and_forget:
return txhash
#
# Receipt
#
try: # TODO: Handle block confirmation exceptions
waiting_for = 'receipt'
try:
waiting_for = "receipt"
if confirmations:
waiting_for = f'{confirmations} confirmations'
emitter.message(f'Waiting {self.TIMEOUT} seconds for {waiting_for}', color='yellow')
receipt = self.client.wait_for_receipt(txhash, timeout=self.TIMEOUT, confirmations=confirmations)
waiting_for = f"{confirmations} confirmations"
emitter.message(
f"Waiting {self.TIMEOUT} seconds for {waiting_for}", color="yellow"
)
receipt = self.client.wait_for_receipt(
txhash, timeout=self.TIMEOUT, confirmations=confirmations
)
except TimeExhausted:
# TODO: #1504 - Handle transaction timeout
raise
else:
self.log.debug(f"[RECEIPT-{transaction_name}] | txhash: {receipt['transactionHash'].hex()}")
self.log.debug(
f"[RECEIPT-{transaction_name}] | txhash: {receipt['transactionHash'].hex()}"
)
#
# Confirmations
#
# Primary check
transaction_status = receipt.get('status', UNKNOWN_TX_STATUS)
transaction_status = receipt.get("status", UNKNOWN_TX_STATUS)
if transaction_status == 0:
failure = f"Transaction transmitted, but receipt returned status code 0. " \
f"Full receipt: \n {pprint.pformat(receipt, indent=2)}"
failure = (
f"Transaction transmitted, but receipt returned status code 0. "
f"Full receipt: \n {pprint.pformat(receipt, indent=2)}"
)
raise self.InterfaceError(failure)
if transaction_status is UNKNOWN_TX_STATUS:
self.log.info(f"Unknown transaction status for {txhash} (receipt did not contain a status field)")
self.log.info(
f"Unknown transaction status for {txhash} (receipt did not contain a status field)"
)
# Secondary check
tx = self.client.get_transaction(txhash)
if tx["gas"] == receipt["gasUsed"]:
raise self.InterfaceError(f"Transaction consumed 100% of transaction gas."
f"Full receipt: \n {pprint.pformat(receipt, indent=2)}")
raise self.InterfaceError(
f"Transaction consumed 100% of transaction gas."
f"Full receipt: \n {pprint.pformat(receipt, indent=2)}"
)
return receipt
def send_async_transaction(
self,
contract_function: ContractFunction,
transacting_power: TransactingPower,
transaction_gas_limit: Optional[int] = None,
gas_estimation_multiplier: float = 1.15,
info: Optional[Dict] = None,
payload: dict = None,
) -> TxReceipt:
transaction = self.build_contract_transaction(
contract_function=contract_function,
sender_address=transacting_power.account,
payload=payload,
transaction_gas_limit=transaction_gas_limit,
gas_estimation_multiplier=gas_estimation_multiplier,
log_now=False,
)
basic_info = {
"name": contract_function.fn_name,
"contract": contract_function.address,
}
if info:
basic_info.update(info)
# TODO: This is a bit of a hack. temporary solution until incoming PR #3382 is merged.
signer = transacting_power._signer._get_signer(transacting_power.account)
async_tx = self.tx_machine.queue_transaction(
info=info,
params=transaction,
signer=signer,
)
return async_tx
@validate_checksum_address
def send_transaction(self,
contract_function: Union[ContractFunction, ContractConstructor],
transacting_power: TransactingPower,
payload: dict = None,
transaction_gas_limit: Optional[int] = None,
gas_estimation_multiplier: Optional[float] = 1.15, # TODO: Workaround for #2635, #2337
confirmations: int = 0,
fire_and_forget: bool = False, # do not wait for receipt. See #2385
replace: bool = False,
) -> Union[TxReceipt, HexBytes]:
if fire_and_forget:
if confirmations > 0:
raise ValueError("Transaction Prevented: "
"Cannot use 'confirmations' and 'fire_and_forget' options together.")
use_pending_nonce = False # TODO: #2385
else:
use_pending_nonce = replace # TODO: #2385
transaction = self.build_contract_transaction(contract_function=contract_function,
sender_address=transacting_power.account,
payload=payload,
transaction_gas_limit=transaction_gas_limit,
gas_estimation_multiplier=gas_estimation_multiplier,
use_pending_nonce=use_pending_nonce)
# Get transaction name
def send_transaction(
self,
contract_function: Union[ContractFunction, ContractConstructor],
transacting_power: TransactingPower,
payload: dict = None,
transaction_gas_limit: Optional[int] = None,
gas_estimation_multiplier: Optional[
float
] = 1.15, # TODO: Workaround for #2635, #2337
) -> TxReceipt:
transaction = self.build_contract_transaction(
contract_function=contract_function,
sender_address=transacting_power.account,
payload=payload,
transaction_gas_limit=transaction_gas_limit,
gas_estimation_multiplier=gas_estimation_multiplier,
log_now=True,
)
try:
transaction_name = contract_function.fn_name.upper()
except AttributeError:
transaction_name = 'DEPLOY' if isinstance(contract_function, ContractConstructor) else 'UNKNOWN'
txhash_or_receipt = self.sign_and_broadcast_transaction(transacting_power=transacting_power,
transaction_dict=transaction,
transaction_name=transaction_name,
confirmations=confirmations,
fire_and_forget=fire_and_forget)
return txhash_or_receipt
transaction_name = (
"DEPLOY"
if isinstance(contract_function, ContractConstructor)
else "UNKNOWN"
)
receipt = self.sign_and_broadcast_transaction(
transacting_power=transacting_power,
transaction_dict=transaction,
transaction_name=transaction_name,
)
return receipt
def get_contract_by_name(
self,
@ -714,12 +787,9 @@ class BlockchainInterfaceFactory:
return bool(cls._interfaces.get(endpoint, False))
@classmethod
def register_interface(cls,
interface: BlockchainInterface,
emitter=None,
force: bool = False
) -> None:
def register_interface(
cls, interface: BlockchainInterface, emitter=None, force: bool = False
) -> None:
endpoint = interface.endpoint
if (endpoint in cls._interfaces) and not force:
raise cls.InterfaceAlreadyInitialized(
@ -763,7 +833,6 @@ class BlockchainInterfaceFactory:
@classmethod
def get_interface(cls, endpoint: str = None) -> Interfaces:
# Try to get an existing cached interface.
if endpoint:
try:

View File

@ -13,8 +13,10 @@ from nucypher_core.ferveo import (
FerveoPublicKey,
)
PHASE1 = 1
PHASE2 = 2
from nucypher.types import PhaseId
PHASE1 = PhaseId(1)
PHASE2 = PhaseId(2)
@dataclass

View File

@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
from typing import List
from urllib.parse import urlparse
from eth_account.datastructures import SignedTransaction
from eth_typing.evm import ChecksumAddress
from hexbytes.main import HexBytes
@ -79,7 +78,7 @@ class Signer(ABC):
return NotImplemented
@abstractmethod
def sign_transaction(self, transaction_dict: dict) -> HexBytes:
def sign_transaction(self, transaction_dict: dict) -> SignedTransaction:
return NotImplemented
@abstractmethod

View File

@ -1,5 +1,3 @@
import json
from json.decoder import JSONDecodeError
from pathlib import Path
@ -8,9 +6,10 @@ from urllib.parse import urlparse
from cytoolz.dicttoolz import dissoc
from eth_account.account import Account
from eth_account.datastructures import SignedTransaction
from eth_account.messages import encode_defunct
from eth_account.signers.local import LocalAccount
from eth_utils.address import is_address, to_checksum_address
from eth_utils.address import is_address, to_canonical_address, to_checksum_address
from hexbytes.main import BytesLike, HexBytes
from nucypher.blockchain.eth.decorators import validate_checksum_address
@ -18,21 +17,28 @@ from nucypher.blockchain.eth.signers.base import Signer
class Web3Signer(Signer):
def __init__(self, client):
super().__init__()
self.__client = client
def _get_signer(self, account: str) -> LocalAccount:
"""Test helper to get a signer from the client's backend"""
account = to_canonical_address(account)
_eth_tester = self.__client.w3.provider.ethereum_tester
signer = Account.from_key(_eth_tester.backend._key_lookup[account]._raw_key)
return signer
@classmethod
def uri_scheme(cls) -> str:
return NotImplemented # web3 signer uses a "passthrough" scheme
@classmethod
def from_signer_uri(cls, uri: str, testnet: bool = False) -> 'Web3Signer':
def from_signer_uri(cls, uri: str, testnet: bool = False) -> "Web3Signer":
from nucypher.blockchain.eth.interfaces import (
BlockchainInterface,
BlockchainInterfaceFactory,
)
try:
blockchain = BlockchainInterfaceFactory.get_or_create_interface(
endpoint=uri
@ -57,9 +63,17 @@ class Web3Signer(Signer):
except AttributeError:
return False
else:
HW_WALLET_URL_PREFIXES = ('trezor', 'ledger')
hw_accounts = [w['accounts'] for w in wallets if w['url'].startswith(HW_WALLET_URL_PREFIXES)]
hw_addresses = [to_checksum_address(account['address']) for sublist in hw_accounts for account in sublist]
HW_WALLET_URL_PREFIXES = ("trezor", "ledger")
hw_accounts = [
w["accounts"]
for w in wallets
if w["url"].startswith(HW_WALLET_URL_PREFIXES)
]
hw_addresses = [
to_checksum_address(account["address"])
for sublist in hw_accounts
for account in sublist
]
return account in hw_addresses
@validate_checksum_address
@ -67,7 +81,9 @@ class Web3Signer(Signer):
if self.is_device(account=account):
unlocked = True
else:
unlocked = self.__client.unlock_account(account=account, password=password, duration=duration)
unlocked = self.__client.unlock_account(
account=account, password=password, duration=duration
)
return unlocked
@validate_checksum_address
@ -83,9 +99,11 @@ class Web3Signer(Signer):
signature = self.__client.sign_message(account=account, message=message)
return HexBytes(signature)
def sign_transaction(self, transaction_dict: dict) -> HexBytes:
signed_raw_transaction = self.__client.sign_transaction(transaction_dict=transaction_dict)
return signed_raw_transaction
def sign_transaction(self, transaction_dict: dict) -> SignedTransaction:
signed_transaction = self.__client.sign_transaction(
transaction_dict=transaction_dict
)
return signed_transaction
class KeystoreSigner(Signer):
@ -116,7 +134,7 @@ class KeystoreSigner(Signer):
@classmethod
def uri_scheme(cls) -> str:
return 'keystore'
return "keystore"
def __read_keystore(self, path: Path) -> None:
"""Read the keystore directory from the disk and populate accounts."""
@ -126,11 +144,15 @@ class KeystoreSigner(Signer):
elif path.is_file():
paths = (path,)
else:
raise self.InvalidSignerURI(f'Invalid keystore file or directory "{path}"')
raise self.InvalidSignerURI(
f'Invalid keystore file or directory "{path}"'
)
except FileNotFoundError:
raise self.InvalidSignerURI(f'No such keystore file or directory "{path}"')
except OSError as exc:
raise self.InvalidSignerURI(f'Error accessing keystore file or directory "{path}": {exc}')
raise self.InvalidSignerURI(
f'Error accessing keystore file or directory "{path}": {exc}'
)
for path in paths:
account, key_metadata = self.__handle_keyfile(path=path)
self.__keys[account] = key_metadata
@ -138,9 +160,9 @@ class KeystoreSigner(Signer):
@staticmethod
def __read_keyfile(path: Path) -> tuple:
"""Read an individual keystore key file from the disk"""
with open(path, 'r') as keyfile:
with open(path, "r") as keyfile:
key_metadata = json.load(keyfile)
address = key_metadata['address']
address = key_metadata["address"]
return address, key_metadata
def __handle_keyfile(self, path: Path) -> Tuple[str, dict]:
@ -171,7 +193,7 @@ class KeystoreSigner(Signer):
return address, key_metadata
@validate_checksum_address
def __get_signer(self, account: str) -> LocalAccount:
def _get_signer(self, account: str) -> LocalAccount:
"""Lookup a known keystore account by its checksum address or raise an error"""
try:
return self.__signers[account]
@ -191,14 +213,14 @@ class KeystoreSigner(Signer):
return self.__path
@classmethod
def from_signer_uri(cls, uri: str, testnet: bool = False) -> 'Signer':
"""Return a keystore signer from URI string i.e. keystore:///my/path/keystore """
def from_signer_uri(cls, uri: str, testnet: bool = False) -> "Signer":
"""Return a keystore signer from URI string i.e. keystore:///my/path/keystore"""
decoded_uri = urlparse(uri)
if decoded_uri.scheme != cls.uri_scheme() or decoded_uri.netloc:
raise cls.InvalidSignerURI(uri)
path = decoded_uri.path
if not path:
raise cls.InvalidSignerURI('Blank signer URI - No keystore path provided')
raise cls.InvalidSignerURI("Blank signer URI - No keystore path provided")
return cls(path=Path(path), testnet=testnet)
@validate_checksum_address
@ -228,10 +250,14 @@ class KeystoreSigner(Signer):
if not password:
# It is possible that password is None here passed from the above layer
# causing Account.decrypt to crash, expecting a value for password.
raise self.AuthenticationFailed('No password supplied to unlock account.')
raise self.AuthenticationFailed(
"No password supplied to unlock account."
)
raise
except ValueError as e:
raise self.AuthenticationFailed("Invalid or incorrect ethereum account password.") from e
raise self.AuthenticationFailed(
"Invalid or incorrect ethereum account password."
) from e
return True
@validate_checksum_address
@ -249,27 +275,31 @@ class KeystoreSigner(Signer):
return account not in self.__signers
@validate_checksum_address
def sign_transaction(self, transaction_dict: dict) -> HexBytes:
def sign_transaction(self, transaction_dict: dict) -> SignedTransaction:
"""
Produce a raw signed ethereum transaction signed by the account specified
in the 'from' field of the transaction dictionary.
"""
sender = transaction_dict['from']
signer = self.__get_signer(account=sender)
sender = transaction_dict["from"]
signer = self._get_signer(account=sender)
# TODO: Handle this at a higher level?
# Do not include a 'to' field for contract creation.
if not transaction_dict['to']:
transaction_dict = dissoc(transaction_dict, 'to')
if not transaction_dict["to"]:
transaction_dict = dissoc(transaction_dict, "to")
raw_transaction = signer.sign_transaction(transaction_dict=transaction_dict).rawTransaction
raw_transaction = signer.sign_transaction(
transaction_dict=transaction_dict
).rawTransaction
return raw_transaction
@validate_checksum_address
def sign_message(self, account: str, message: bytes, **kwargs) -> HexBytes:
signer = self.__get_signer(account=account)
signature = signer.sign_message(signable_message=encode_defunct(primitive=message)).signature
signer = self._get_signer(account=account)
signature = signer.sign_message(
signable_message=encode_defunct(primitive=message)
).signature
return HexBytes(signature)
@ -325,15 +355,15 @@ class InMemorySigner(Signer):
raise self.AccountLocked(account=account)
@validate_checksum_address
def sign_transaction(self, transaction_dict: dict) -> HexBytes:
def sign_transaction(self, transaction_dict: dict) -> SignedTransaction:
sender = transaction_dict["from"]
signer = self.__get_signer(account=sender)
if not transaction_dict["to"]:
transaction_dict = dissoc(transaction_dict, "to")
raw_transaction = signer.sign_transaction(
signed_transaction = signer.sign_transaction(
transaction_dict=transaction_dict
).rawTransaction
return raw_transaction
return signed_transaction
@validate_checksum_address
def sign_message(self, account: str, message: bytes, **kwargs) -> HexBytes:

View File

@ -14,7 +14,7 @@ class OperatorBondedTracker(SimpleTask):
def __init__(self, ursula):
self._ursula = ursula
super().__init__()
super().__init__(interval=self.INTERVAL)
def run(self) -> None:
application_agent = ContractAgency.get_agent(

View File

@ -1,16 +1,17 @@
import datetime
import os
import time
from typing import Callable, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple
import maya
from atxm.tx import AsyncTx, FutureTx
from prometheus_client import REGISTRY, Gauge
from twisted.internet import threads
from web3.datastructures import AttributeDict
from nucypher.blockchain.eth import actors
from nucypher.blockchain.eth.models import Coordinator
from nucypher.policy.conditions.utils import camel_case_to_snake
from nucypher.types import PhaseId, RitualId
from nucypher.utilities.cache import TTLCache
from nucypher.utilities.events import EventScanner, JSONifiedState
from nucypher.utilities.logging import Logger
@ -50,15 +51,17 @@ class EventScannerTask(SimpleTask):
INTERVAL = 120 # seconds
def __init__(self, scanner: Callable, *args, **kwargs):
def __init__(self, scanner: Callable):
self.scanner = scanner
super().__init__(*args, **kwargs)
super().__init__(interval=self.INTERVAL)
def run(self):
self.scanner()
def handle_errors(self, *args, **kwargs):
self.log.warn("Error during ritual event scanning: {}".format(args[0].getTraceback()))
self.log.warn(
"Error during ritual event scanning: {}".format(args[0].getTraceback())
)
if not self._task.running:
self.log.warn("Restarting event scanner task!")
self.start(now=False) # take a breather
@ -67,9 +70,7 @@ class EventScannerTask(SimpleTask):
class ActiveRitualTracker:
CHAIN_REORG_SCAN_WINDOW = 20
MAX_CHUNK_SIZE = 10000
MIN_CHUNK_SIZE = 60 # 60 blocks @ 2s per block on Polygon = 120s of blocks (somewhat related to interval)
# how often to check/purge for expired cached values - 8hrs?
@ -97,7 +98,7 @@ class ActiveRitualTracker:
def __init__(
self,
operator: "actors.Operator",
operator,
persistent: bool = False, # TODO: use persistent storage?
):
self.log = Logger("RitualTracker")
@ -122,6 +123,8 @@ class ActiveRitualTracker:
self.contract.events.EndRitual,
]
self.__phase_txs: Dict[Tuple[RitualId, PhaseId], FutureTx] = {}
# TODO: Remove the default JSON-RPC retry middleware
# as it correctly cannot handle eth_getLogs block range throttle down.
# self.web3.middleware_onion.remove(http_retry_request_middleware)
@ -164,6 +167,10 @@ class ActiveRitualTracker:
def contract(self):
return self.coordinator_agent.contract
@property
def active_rituals(self) -> Dict[Tuple[RitualId, PhaseId], AsyncTx]:
return self.__phase_txs
# TODO: should sample_window_size be additionally configurable/chain-dependent?
def _get_first_scan_start_block_number(self, sample_window_size: int = 100) -> int:
"""
@ -172,7 +179,7 @@ class ActiveRitualTracker:
w3 = self.web3
timeout = self.coordinator_agent.get_timeout()
latest_block = w3.eth.get_block('latest')
latest_block = w3.eth.get_block("latest")
if latest_block.number == 0:
return 0
@ -389,8 +396,10 @@ class ActiveRitualTracker:
camel_case_to_snake(k): v for k, v in ritual_event.args.items()
}
event_type = getattr(self.contract.events, ritual_event.event)
def task():
self.actions[event_type](timestamp=timestamp, **formatted_kwargs)
if defer:
d = threads.deferToThread(task)
d.addErrback(self.task.handle_errors)
@ -416,14 +425,18 @@ class ActiveRitualTracker:
def __scan(self, start_block, end_block, account):
# Run the scan
self.log.debug(f"({account[:8]}) Scanning events in block range {start_block} - {end_block}")
self.log.debug(
f"({account[:8]}) Scanning events in block range {start_block} - {end_block}"
)
start = time.time()
result, total_chunks_scanned = self.scanner.scan(start_block, end_block)
if self.persistent:
self.state.save()
duration = time.time() - start
self.log.debug(f"Scanned total of {len(result)} events, in {duration} seconds, "
f"total {total_chunks_scanned} chunk scans performed")
self.log.debug(
f"Scanned total of {len(result)} events, in {duration} seconds, "
f"total {total_chunks_scanned} chunk scans performed"
)
def scan(self):
"""

View File

@ -950,14 +950,14 @@ class Ursula(Teacher, Character, Operator):
preflight: bool = True,
block_until_ready: bool = True,
eager: bool = False,
transaction_tracking: bool = True,
) -> None:
"""Schedule and start select ursula services, then optionally start the reactor."""
# Connect to Provider
if not BlockchainInterfaceFactory.is_interface_initialized(
endpoint=self.eth_endpoint
):
BlockchainInterfaceFactory.initialize_interface(endpoint=self.eth_endpoint)
BlockchainInterfaceFactory.get_or_create_interface(endpoint=self.eth_endpoint)
polygon = BlockchainInterfaceFactory.get_or_create_interface(
endpoint=self.polygon_endpoint
)
if preflight:
self.__preflight()
@ -967,12 +967,19 @@ class Ursula(Teacher, Character, Operator):
#
if emitter:
emitter.message("Starting services", color="yellow")
emitter.message("Starting services...", color="yellow")
if discovery and not self.lonely:
self.start_learning_loop(now=eager)
if emitter:
emitter.message(f"✓ Node Discovery ({self.domain})", color="green")
emitter.message(f"✓ P2P Networking ({self.domain})", color="green")
if transaction_tracking:
# Uncomment to enable tracking for both chains.
# mainnet.tracker.start(now=False)
polygon.tx_machine.start(now=False)
if emitter:
emitter.message("✓ Transaction Autopilot", color="green")
if ritual_tracking:
self.ritual_tracker.start()
@ -1147,7 +1154,6 @@ class Ursula(Teacher, Character, Operator):
seed_uri = f"{seednode_metadata.checksum_address}@{seednode_metadata.rest_host}:{seednode_metadata.rest_port}"
return cls.from_seed_and_stake_info(seed_uri=seed_uri, *args, **kwargs)
@classmethod
def from_teacher_uri(
cls,

View File

@ -2,8 +2,8 @@ import inspect
from typing import List, Optional, Tuple, Union
from eth_account._utils.signing import to_standard_signature_bytes
from eth_account.datastructures import SignedTransaction
from eth_typing.evm import ChecksumAddress
from hexbytes import HexBytes
from nucypher_core import (
EncryptedThresholdDecryptionRequest,
EncryptedThresholdDecryptionResponse,
@ -201,7 +201,7 @@ class TransactingPower(CryptoPowerUp):
# from the recovery byte, bringing it to the standard choice of {0, 1}.
return to_standard_signature_bytes(signature)
def sign_transaction(self, transaction_dict: dict) -> HexBytes:
def sign_transaction(self, transaction_dict: dict) -> SignedTransaction:
"""Signs the transaction with the private key of the TransactingPower."""
return self._signer.sign_transaction(transaction_dict=transaction_dict)

View File

@ -1,9 +1,10 @@
from typing import NewType, TypeVar
from nucypher.blockchain.eth import agents
ERC20UNits = NewType("ERC20UNits", int)
NuNits = NewType("NuNits", ERC20UNits)
TuNits = NewType("TuNits", ERC20UNits)
Agent = TypeVar("Agent", bound="agents.EthereumContractAgent")
Agent = TypeVar("Agent", bound="agents.EthereumContractAgent") # noqa: F821
RitualId = int
PhaseId = int

View File

@ -9,14 +9,15 @@ from nucypher.utilities.logging import Logger
class SimpleTask(ABC):
"""Simple Twisted Looping Call abstract base class."""
INTERVAL = 60 # 60s default
INTERVAL = NotImplemented
CLOCK = reactor
def __init__(self, interval: float = None):
self.interval = interval or self.INTERVAL
self.log = Logger(self.__class__.__name__)
self._task = LoopingCall(self.run)
# self.__task.clock = self.CLOCK
self._task.clock = self.CLOCK
@property
def running(self) -> bool:
@ -28,7 +29,7 @@ class SimpleTask(ABC):
if not self.running:
d = self._task.start(interval=self.interval, now=now)
d.addErrback(self.handle_errors)
# return d
return d
def stop(self):
"""Stop task."""
@ -48,5 +49,5 @@ class SimpleTask(ABC):
@staticmethod
def clean_traceback(failure: Failure) -> str:
# FIXME: Amazing.
cleaned_traceback = failure.getTraceback().replace('{', '').replace('}', '')
cleaned_traceback = failure.getTraceback().replace("{", "").replace("}", "")
return cleaned_traceback

View File

@ -6,13 +6,11 @@ import pytest
import pytest_twisted
from hexbytes import HexBytes
from prometheus_client import REGISTRY
from twisted.internet.threads import deferToThread
from nucypher.blockchain.eth.agents import ContractAgency, SubscriptionManagerAgent
from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.blockchain.eth.models import Coordinator
from nucypher.blockchain.eth.signers.software import Web3Signer
from nucypher.blockchain.eth.trackers.dkg import EventScannerTask
from nucypher.characters.lawful import Enrico, Ursula
from nucypher.policy.conditions.evm import ContractCondition, RPCCondition
from nucypher.policy.conditions.lingo import (
@ -24,18 +22,35 @@ from nucypher.policy.conditions.lingo import (
from nucypher.policy.conditions.time import TimeCondition
from tests.constants import TEST_ETH_PROVIDER_URI, TESTERCHAIN_CHAIN_ID
# constants
DKG_SIZE = 4
RITUAL_ID = 0
# This is a hack to make the tests run faster
EventScannerTask.INTERVAL = 1
TIME_TRAVEL_INTERVAL = 60
@pytest.fixture(scope="module")
def ritual_id():
return 0
# The message to encrypt and its conditions
PLAINTEXT = "peace at dawn"
DURATION = 48 * 60 * 60
@pytest.fixture(scope="module")
def dkg_size():
return 4
@pytest.fixture(scope="module")
def duration():
return 48 * 60 * 60
@pytest.fixture(scope="module")
def plaintext():
return "peace at dawn"
@pytest.fixture(scope="module")
def interval(testerchain):
return testerchain.tx_machine._task.interval
@pytest.fixture(scope="module")
def signer(testerchain):
return Web3Signer(testerchain.client)
@pytest.fixture(scope="module")
@ -94,282 +109,236 @@ def condition(test_registry):
return ConditionLingo(condition_to_use).to_dict()
@pytest.fixture(scope='module')
def cohort(ursulas):
"""Creates a cohort of Ursulas"""
nodes = list(sorted(ursulas[:DKG_SIZE], key=lambda x: int(x.checksum_address, 16)))
assert len(nodes) == DKG_SIZE # sanity check
@pytest.fixture(scope="module", autouse=True)
def transaction_tracker(testerchain, coordinator_agent):
testerchain.tx_machine.w3 = coordinator_agent.blockchain.w3
testerchain.tx_machine.start()
@pytest.fixture(scope="module")
def cohort(testerchain, clock, coordinator_agent, ursulas, dkg_size):
nodes = list(sorted(ursulas[:dkg_size], key=lambda x: int(x.checksum_address, 16)))
assert len(nodes) == dkg_size
for node in nodes:
node.ritual_tracker.task._task.clock = clock
node.ritual_tracker.start()
return nodes
@pytest_twisted.inlineCallbacks()
def test_ursula_ritualist(
condition,
testerchain,
@pytest.fixture(scope="module")
def threshold_message_kit(coordinator_agent, plaintext, condition, signer, ritual_id):
encrypting_key = coordinator_agent.get_ritual_public_key(ritual_id=ritual_id)
enrico = Enrico(encrypting_key=encrypting_key, signer=signer)
return enrico.encrypt_for_dkg(plaintext=plaintext.encode(), conditions=condition)
def test_dkg_initiation(
coordinator_agent,
global_allow_list,
cohort,
initiator,
bob,
ritual_token,
accounts,
initiator,
cohort,
global_allow_list,
testerchain,
ritual_token,
ritual_id,
duration,
):
"""Tests the DKG and the encryption/decryption of a message"""
signer = Web3Signer(client=testerchain.client)
print("==================== INITIALIZING ====================")
cohort_staking_provider_addresses = list(u.checksum_address for u in cohort)
# Round 0 - Initiate the ritual
def initialize():
"""Initiates the ritual"""
print("==================== INITIALIZING ====================")
cohort_staking_provider_addresses = list(u.checksum_address for u in cohort)
# Approve the ritual token for the coordinator agent to spend
amount = coordinator_agent.get_ritual_initiation_cost(
providers=cohort_staking_provider_addresses, duration=duration
)
ritual_token.approve(
coordinator_agent.contract_address,
amount,
sender=accounts[initiator.transacting_power.account],
)
# Approve the ritual token for the coordinator agent to spend
amount = coordinator_agent.get_ritual_initiation_cost(
providers=cohort_staking_provider_addresses, duration=DURATION
)
ritual_token.approve(
coordinator_agent.contract_address,
amount,
sender=accounts[initiator.transacting_power.account],
)
receipt = coordinator_agent.initiate_ritual(
providers=cohort_staking_provider_addresses,
authority=initiator.transacting_power.account,
duration=duration,
access_controller=global_allow_list.address,
transacting_power=initiator.transacting_power,
)
receipt = coordinator_agent.initiate_ritual(
providers=cohort_staking_provider_addresses,
authority=initiator.transacting_power.account,
duration=DURATION,
access_controller=global_allow_list.address,
transacting_power=initiator.transacting_power,
)
return receipt
testerchain.time_travel(seconds=1)
testerchain.wait_for_receipt(receipt["transactionHash"])
# Round 0 - Initiate the ritual
def check_initialize(receipt):
"""Checks the initialization of the ritual"""
print("==================== CHECKING INITIALIZATION ====================")
testerchain.wait_for_receipt(receipt['transactionHash'])
# check that the ritual was created on-chain
assert coordinator_agent.number_of_rituals() == ritual_id + 1
assert (
coordinator_agent.get_ritual_status(ritual_id)
== Coordinator.RitualStatus.DKG_AWAITING_TRANSCRIPTS
)
testerchain.wait_for_receipt(receipt["transactionHash"])
# check that the ritual was created on-chain
assert coordinator_agent.number_of_rituals() == RITUAL_ID + 1
@pytest_twisted.inlineCallbacks
def test_dkg_finality(
coordinator_agent, ritual_id, cohort, clock, interval, testerchain
):
print("==================== AWAITING DKG FINALITY ====================")
while (
coordinator_agent.get_ritual_status(ritual_id)
!= Coordinator.RitualStatus.ACTIVE
):
yield clock.advance(interval)
yield testerchain.time_travel(seconds=1)
testerchain.tx_machine.stop()
assert not testerchain.tx_machine.running
status = coordinator_agent.get_ritual_status(ritual_id)
assert status == Coordinator.RitualStatus.ACTIVE
status = coordinator_agent.get_ritual_status(ritual_id)
assert status == Coordinator.RitualStatus.ACTIVE
last_scanned_block = REGISTRY.get_sample_value(
"ritual_events_last_scanned_block_number"
)
assert last_scanned_block > 0
yield
def test_transcript_publication(coordinator_agent, cohort, ritual_id, dkg_size):
print("==================== VERIFYING DKG FINALITY ====================")
for ursula in cohort:
assert (
coordinator_agent.get_ritual_status(RITUAL_ID)
== Coordinator.RitualStatus.DKG_AWAITING_TRANSCRIPTS
)
# time travel has a side effect of mining a block so that the scanner will definitively
# pick up ritual event
testerchain.time_travel(seconds=1)
for ursula in cohort:
# this is a testing hack to make the event scanner work
# normally it's called by the reactor clock in a loop
ursula.ritual_tracker.task.run()
# nodes received `StartRitual` and submitted their transcripts
assert (
len(
coordinator_agent.get_participant(
ritual_id=RITUAL_ID,
provider=ursula.checksum_address,
transcript=True,
).transcript
)
> 0
), "ursula posted transcript to Coordinator"
def block_until_dkg_finalized(_):
"""simulates the passage of time and the execution of the event scanner"""
print("==================== BLOCKING UNTIL DKG FINALIZED ====================")
while (
coordinator_agent.get_ritual_status(RITUAL_ID)
!= Coordinator.RitualStatus.ACTIVE
):
for ursula in cohort:
# this is a testing hack to make the event scanner work,
# normally it's called by the reactor clock in a loop
ursula.ritual_tracker.task.run()
testerchain.time_travel(seconds=TIME_TRAVEL_INTERVAL)
# Ensure that all events processed, including EndRitual
for ursula in cohort:
ursula.ritual_tracker.task.run()
def check_finality(_):
"""Checks the finality of the DKG"""
print("==================== CHECKING DKG FINALITY ====================")
status = coordinator_agent.get_ritual_status(RITUAL_ID)
assert status == Coordinator.RitualStatus.ACTIVE
for ursula in cohort:
participant = coordinator_agent.get_participant(
RITUAL_ID, ursula.checksum_address, True
len(
coordinator_agent.get_participant(
ritual_id=ritual_id,
provider=ursula.checksum_address,
transcript=True,
).transcript
)
assert participant.transcript
assert participant.aggregated
> 0
), "no transcript found for ursula"
print(f"Ursula {ursula.checksum_address} has submitted a transcript")
last_scanned_block = REGISTRY.get_sample_value(
"ritual_events_last_scanned_block_number"
)
assert last_scanned_block > 0
def check_participant_pagination(_):
print("================ PARTICIPANT PAGINATION ================")
pagination_sizes = range(0, DKG_SIZE) # 0 means get all in one call
for page_size in pagination_sizes:
with patch.object(
coordinator_agent, "_get_page_size", return_value=page_size
):
ritual = coordinator_agent.get_ritual(RITUAL_ID, transcripts=True)
for i, participant in enumerate(ritual.participants):
assert participant.provider == cohort[i].checksum_address
assert participant.aggregated is True
assert participant.transcript
assert participant.decryption_request_static_key
def test_get_participants(coordinator_agent, cohort, ritual_id, dkg_size):
pagination_sizes = range(0, dkg_size) # 0 means get all in one call
for page_size in pagination_sizes:
with patch.object(coordinator_agent, "_get_page_size", return_value=page_size):
ritual = coordinator_agent.get_ritual(ritual_id, transcripts=True)
for i, participant in enumerate(ritual.participants):
assert participant.provider == cohort[i].checksum_address
assert participant.aggregated is True
assert participant.transcript
assert participant.decryption_request_static_key
assert len(ritual.participants) == DKG_SIZE
assert len(ritual.participants) == dkg_size
def check_encrypt(_):
"""Encrypts a message and returns the ciphertext and conditions"""
print("==================== DKG ENCRYPTION ====================")
encrypting_key = coordinator_agent.get_ritual_public_key(ritual_id=RITUAL_ID)
def test_encrypt(
coordinator_agent, condition, ritual_id, plaintext, testerchain, signer
):
print("==================== DKG ENCRYPTION ====================")
encrypting_key = coordinator_agent.get_ritual_public_key(ritual_id=ritual_id)
plaintext = plaintext.encode()
enrico = Enrico(encrypting_key=encrypting_key, signer=signer)
print(f"encrypting for DKG with key {bytes(encrypting_key).hex()}")
tmk = enrico.encrypt_for_dkg(plaintext=plaintext, conditions=condition)
assert tmk.ciphertext_header
# prepare message and conditions
plaintext = PLAINTEXT.encode()
# create Enrico
enrico = Enrico(encrypting_key=encrypting_key, signer=signer)
# encrypt
print(f"encrypting for DKG with key {bytes(encrypting_key).hex()}")
threshold_message_kit = enrico.encrypt_for_dkg(
plaintext=plaintext, conditions=condition
)
return threshold_message_kit
def check_unauthorized_decrypt(threshold_message_kit):
"""Attempts to decrypt a message before Enrico is authorized to use the ritual"""
print("======== DKG DECRYPTION UNAUTHORIZED ENCRYPTION ========")
# ritual_id, ciphertext, conditions are obtained from the side channel
bob.start_learning_loop(now=True)
with pytest.raises(
Ursula.NotEnoughUrsulas,
match=f"Encrypted data not authorized for ritual {RITUAL_ID}",
):
_ = bob.threshold_decrypt(
threshold_message_kit=threshold_message_kit,
)
# check prometheus metric for decryption requests
# since all running on the same machine - the value is not per-ursula but rather all
num_failures = REGISTRY.get_sample_value(
"threshold_decryption_num_failures_total"
)
assert len(cohort) == int(num_failures) # each ursula in cohort had a failure
print("========= UNAUTHORIZED DECRYPTION UNSUCCESSFUL =========")
return threshold_message_kit
def check_decrypt(threshold_message_kit):
"""Decrypts a message and checks that it matches the original plaintext"""
# authorize Enrico to encrypt for ritual
global_allow_list.authorize(
RITUAL_ID,
[signer.accounts[0]],
sender=accounts[initiator.transacting_power.account],
)
print("==================== DKG DECRYPTION ====================")
# ritual_id, ciphertext, conditions are obtained from the side channel
bob.start_learning_loop(now=True)
cleartext = bob.threshold_decrypt(
@pytest_twisted.inlineCallbacks
def test_unauthorized_decryption(bob, cohort, threshold_message_kit, ritual_id):
print("======== DKG DECRYPTION (UNAUTHORIZED) ========")
bob.start_learning_loop(now=True)
with pytest.raises(
Ursula.NotEnoughUrsulas,
match=f"Encrypted data not authorized for ritual {ritual_id}",
):
yield bob.threshold_decrypt(
threshold_message_kit=threshold_message_kit,
)
assert bytes(cleartext) == PLAINTEXT.encode()
# check prometheus metric for decryption requests
# since all running on the same machine - the value is not per-ursula but rather all
num_successes = REGISTRY.get_sample_value(
"threshold_decryption_num_successes_total"
)
# check prometheus metric for decryption requests
# since all running on the same machine - the value is not per-ursula but rather all
num_failures = REGISTRY.get_sample_value("threshold_decryption_num_failures_total")
assert len(cohort) == int(num_failures) # each ursula in cohort had a failure
yield
ritual = coordinator_agent.get_ritual(RITUAL_ID)
# at least a threshold of ursulas were successful (concurrency)
assert int(num_successes) >= ritual.threshold
# decrypt again (should only use cached values)
with patch.object(
coordinator_agent,
"get_provider_public_key",
side_effect=RuntimeError(
"should not be called to create validators; cache should be used"
),
):
# would like to but can't patch agent.get_ritual, since bob uses it
cleartext = bob.threshold_decrypt(
threshold_message_kit=threshold_message_kit,
)
assert bytes(cleartext) == PLAINTEXT.encode()
print("==================== DECRYPTION SUCCESSFUL ====================")
def check_decrypt_without_any_cached_values(threshold_message_kit, ritual_id, cohort, bob, coordinator_agent, plaintext):
print("==================== DKG DECRYPTION NO CACHE ====================")
original_validators = cohort[0].dkg_storage.get_validators(ritual_id)
return threshold_message_kit
for ursula in cohort:
ursula.dkg_storage.clear(ritual_id)
assert ursula.dkg_storage.get_validators(ritual_id) is None
assert ursula.dkg_storage.get_active_ritual(ritual_id) is None
def check_decrypt_without_any_cached_values(threshold_message_kit):
print("==================== DKG DECRYPTION NO CACHE ====================")
original_validators = cohort[0].dkg_storage.get_validators(RITUAL_ID)
bob.start_learning_loop(now=True)
cleartext = bob.threshold_decrypt(
threshold_message_kit=threshold_message_kit,
)
assert bytes(cleartext) == plaintext.encode()
for ursula in cohort:
ursula.dkg_storage.clear(RITUAL_ID)
assert ursula.dkg_storage.get_validators(RITUAL_ID) is None
assert ursula.dkg_storage.get_active_ritual(RITUAL_ID) is None
ritual = coordinator_agent.get_ritual(ritual_id)
num_used_ursulas = 0
for ursula_index, ursula in enumerate(cohort):
stored_ritual = ursula.dkg_storage.get_active_ritual(ritual_id)
if not stored_ritual:
# this ursula was not used for threshold decryption; skip
continue
assert stored_ritual == ritual
bob.start_learning_loop(now=True)
cleartext = bob.threshold_decrypt(
threshold_message_kit=threshold_message_kit,
)
assert bytes(cleartext) == PLAINTEXT.encode()
stored_validators = ursula.dkg_storage.get_validators(ritual_id)
num_used_ursulas += 1
for v_index, v in enumerate(stored_validators):
assert v.address == original_validators[v_index].address
assert v.public_key == original_validators[v_index].public_key
ritual = coordinator_agent.get_ritual(RITUAL_ID)
num_used_ursulas = 0
for ursula_index, ursula in enumerate(cohort):
stored_ritual = ursula.dkg_storage.get_active_ritual(RITUAL_ID)
if not stored_ritual:
# this ursula was not used for threshold decryption; skip
continue
assert stored_ritual == ritual
assert num_used_ursulas >= ritual.threshold
print("===================== DECRYPTION SUCCESSFUL =====================")
stored_validators = ursula.dkg_storage.get_validators(RITUAL_ID)
num_used_ursulas += 1
for v_index, v in enumerate(stored_validators):
assert v.address == original_validators[v_index].address
assert v.public_key == original_validators[v_index].public_key
assert num_used_ursulas >= ritual.threshold
print("===================== DECRYPTION SUCCESSFUL =====================")
@pytest_twisted.inlineCallbacks
def test_authorized_decryption(
bob,
global_allow_list,
accounts,
coordinator_agent,
threshold_message_kit,
signer,
initiator,
ritual_id,
plaintext,
):
print("==================== DKG DECRYPTION (AUTHORIZED) ====================")
# authorize Enrico to encrypt for ritual
global_allow_list.authorize(
ritual_id,
[signer.accounts[0]],
sender=accounts[initiator.transacting_power.account],
)
def error_handler(e):
"""Prints the error and raises it"""
print("==================== ERROR ====================")
print(e.getTraceback())
raise e
# ritual_id, ciphertext, conditions are obtained from the side channel
bob.start_learning_loop(now=True)
cleartext = yield bob.threshold_decrypt(
threshold_message_kit=threshold_message_kit,
)
assert bytes(cleartext) == plaintext.encode()
# order matters
callbacks = [
check_initialize,
block_until_dkg_finalized,
check_finality,
check_participant_pagination,
check_encrypt,
check_unauthorized_decrypt,
check_decrypt,
check_decrypt_without_any_cached_values,
]
# check prometheus metric for decryption requests
# since all running on the same machine - the value is not per-ursula but rather all
num_successes = REGISTRY.get_sample_value(
"threshold_decryption_num_successes_total"
)
d = deferToThread(initialize)
for callback in callbacks:
d.addCallback(callback)
d.addErrback(error_handler)
yield d
ritual = coordinator_agent.get_ritual(ritual_id)
# at least a threshold of ursulas were successful (concurrency)
assert int(num_successes) >= ritual.threshold
yield
def test_encryption_and_decryption_prometheus_metrics():
print("==================== METRICS ====================")
# check prometheus metric for decryption requests
# since all running on the same machine - the value is not per-ursula but rather all
num_decryption_failures = REGISTRY.get_sample_value(

View File

@ -1,12 +1,11 @@
import os
import pytest
import pytest_twisted
from eth_utils import keccak
from nucypher_core import SessionStaticSecret
from nucypher.blockchain.eth.agents import (
CoordinatorAgent,
)
from nucypher.blockchain.eth.agents import CoordinatorAgent
from nucypher.blockchain.eth.models import Coordinator
from nucypher.blockchain.eth.signers.software import Web3Signer
from nucypher.crypto.powers import TransactingPower
@ -112,18 +111,27 @@ def test_initiate_ritual(
assert ritual_dkg_key is None # no dkg key available until ritual is completed
def test_post_transcript(agent, transcripts, transacting_powers, testerchain):
@pytest_twisted.inlineCallbacks
def test_post_transcript(agent, transcripts, transacting_powers, testerchain, clock):
ritual_id = agent.number_of_rituals() - 1
txs = []
for i, transacting_power in enumerate(transacting_powers):
txhash = agent.post_transcript(
async_tx = agent.post_transcript(
ritual_id=ritual_id,
transcript=transcripts[i],
transacting_power=transacting_power,
)
txs.append(async_tx)
receipt = testerchain.wait_for_receipt(txhash)
testerchain.tx_machine.start()
while not all([tx.final for tx in txs]):
yield clock.advance(testerchain.tx_machine._task.interval)
testerchain.tx_machine.stop()
for i, atx in enumerate(txs):
post_transcript_events = (
agent.contract.events.TranscriptPosted().process_receipt(receipt)
agent.contract.events.TranscriptPosted().process_receipt(atx.receipt)
)
# assert len(post_transcript_events) == 1
event = post_transcript_events[0]
@ -142,6 +150,7 @@ def test_post_transcript(agent, transcripts, transacting_powers, testerchain):
assert ritual_dkg_key is None # no dkg key available until ritual is completed
@pytest_twisted.inlineCallbacks
def test_post_aggregation(
agent,
aggregated_transcript,
@ -149,24 +158,35 @@ def test_post_aggregation(
transacting_powers,
cohort,
testerchain,
clock,
):
testerchain.tx_machine.start()
ritual_id = agent.number_of_rituals() - 1
participant_public_keys = {}
txs = []
participant_public_key = SessionStaticSecret.random().public_key()
for i, transacting_power in enumerate(transacting_powers):
participant_public_key = SessionStaticSecret.random().public_key()
txhash = agent.post_aggregation(
async_tx = agent.post_aggregation(
ritual_id=ritual_id,
aggregated_transcript=aggregated_transcript,
public_key=dkg_public_key,
participant_public_key=participant_public_key,
transacting_power=transacting_power,
)
txs.append(async_tx)
testerchain.tx_machine.start()
while not all([tx.final for tx in txs]):
yield clock.advance(testerchain.tx_machine._task.interval)
testerchain.tx_machine.stop()
for i, atx in enumerate(txs):
participant_public_keys[cohort[i]] = participant_public_key
receipt = testerchain.wait_for_receipt(txhash)
post_aggregation_events = (
agent.contract.events.AggregationPosted().process_receipt(receipt)
agent.contract.events.AggregationPosted().process_receipt(atx.receipt)
)
# assert len(post_aggregation_events) == 1
assert len(post_aggregation_events) == 1
event = post_aggregation_events[0]
assert event["args"]["ritualId"] == ritual_id
assert event["args"]["aggregatedTranscriptDigest"] == keccak(

View File

@ -16,7 +16,6 @@ TransactingPower.lock_account = LOCK_FUNCTION
def test_character_transacting_power_signing(testerchain, test_registry):
# Pretend to be a character.
eth_address = testerchain.etherbase_account
signer = Character(
@ -28,9 +27,11 @@ def test_character_transacting_power_signing(testerchain, test_registry):
)
# Manually consume the power up
transacting_power = TransactingPower(password=INSECURE_DEVELOPMENT_PASSWORD,
signer=Web3Signer(testerchain.client),
account=eth_address)
transacting_power = TransactingPower(
password=INSECURE_DEVELOPMENT_PASSWORD,
signer=Web3Signer(testerchain.client),
account=eth_address,
)
signer._crypto_power.consume_power_up(transacting_power)
@ -40,68 +41,82 @@ def test_character_transacting_power_signing(testerchain, test_registry):
assert power == transacting_power
# Sign Message
data_to_sign = b'Premium Select Luxury Pencil Holder'
data_to_sign = b"Premium Select Luxury Pencil Holder"
signature = power.sign_message(message=data_to_sign)
is_verified = verify_eip_191(address=eth_address, message=data_to_sign, signature=signature)
is_verified = verify_eip_191(
address=eth_address, message=data_to_sign, signature=signature
)
assert is_verified is True
# Sign Transaction
transaction_dict = {'nonce': testerchain.client.w3.eth.get_transaction_count(eth_address),
'gasPrice': testerchain.client.w3.eth.gas_price,
'gas': 100000,
'from': eth_address,
'to': testerchain.unassigned_accounts[1],
'value': 1,
'data': b''}
transaction_dict = {
"nonce": testerchain.client.w3.eth.get_transaction_count(eth_address),
"gasPrice": testerchain.client.w3.eth.gas_price,
"gas": 100000,
"from": eth_address,
"to": testerchain.unassigned_accounts[1],
"value": 1,
"data": b"",
}
signed_transaction = power.sign_transaction(transaction_dict=transaction_dict)
# Demonstrate that the transaction is valid RLP encoded.
restored_transaction = Transaction.from_bytes(serialized_bytes=signed_transaction)
restored_transaction = Transaction.from_bytes(
serialized_bytes=signed_transaction.rawTransaction
)
restored_dict = restored_transaction.as_dict()
assert to_checksum_address(restored_dict['to']) == transaction_dict['to']
assert to_checksum_address(restored_dict["to"]) == transaction_dict["to"]
def test_transacting_power_sign_message(testerchain):
# Manually create a TransactingPower
eth_address = testerchain.etherbase_account
power = TransactingPower(password=INSECURE_DEVELOPMENT_PASSWORD,
signer=Web3Signer(testerchain.client),
account=eth_address)
power = TransactingPower(
password=INSECURE_DEVELOPMENT_PASSWORD,
signer=Web3Signer(testerchain.client),
account=eth_address,
)
# Manually unlock
power.unlock(password=INSECURE_DEVELOPMENT_PASSWORD)
# Sign
data_to_sign = b'Premium Select Luxury Pencil Holder'
data_to_sign = b"Premium Select Luxury Pencil Holder"
signature = power.sign_message(message=data_to_sign)
# Verify
is_verified = verify_eip_191(address=eth_address, message=data_to_sign, signature=signature)
is_verified = verify_eip_191(
address=eth_address, message=data_to_sign, signature=signature
)
assert is_verified is True
# Test invalid address/pubkey pair
is_verified = verify_eip_191(address=testerchain.client.accounts[1],
message=data_to_sign,
signature=signature)
is_verified = verify_eip_191(
address=testerchain.client.accounts[1],
message=data_to_sign,
signature=signature,
)
assert is_verified is False
def test_transacting_power_sign_transaction(testerchain):
eth_address = testerchain.unassigned_accounts[2]
power = TransactingPower(password=INSECURE_DEVELOPMENT_PASSWORD,
signer=Web3Signer(testerchain.client),
account=eth_address)
power = TransactingPower(
password=INSECURE_DEVELOPMENT_PASSWORD,
signer=Web3Signer(testerchain.client),
account=eth_address,
)
transaction_dict = {'nonce': testerchain.client.w3.eth.get_transaction_count(eth_address),
'gasPrice': testerchain.client.w3.eth.gas_price,
'gas': 100000,
'from': eth_address,
'to': testerchain.unassigned_accounts[1],
'value': 1,
'data': b''}
transaction_dict = {
"nonce": testerchain.client.w3.eth.get_transaction_count(eth_address),
"gasPrice": testerchain.client.w3.eth.gas_price,
"gas": 100000,
"from": eth_address,
"to": testerchain.unassigned_accounts[1],
"value": 1,
"data": b"",
}
# Sign
power.activate()
@ -109,13 +124,16 @@ def test_transacting_power_sign_transaction(testerchain):
# Demonstrate that the transaction is valid RLP encoded.
from eth_account._utils.legacy_transactions import Transaction
restored_transaction = Transaction.from_bytes(serialized_bytes=signed_transaction)
restored_transaction = Transaction.from_bytes(
serialized_bytes=signed_transaction.rawTransaction
)
restored_dict = restored_transaction.as_dict()
assert to_checksum_address(restored_dict['to']) == transaction_dict['to']
assert to_checksum_address(restored_dict["to"]) == transaction_dict["to"]
# Try signing with missing transaction fields
del transaction_dict['gas']
del transaction_dict['nonce']
del transaction_dict["gas"]
del transaction_dict["nonce"]
with pytest.raises(TypeError):
power.sign_transaction(transaction_dict=transaction_dict)
@ -127,21 +145,31 @@ def test_transacting_power_sign_agent_transaction(testerchain, coordinator_agent
g2_point
)
payload = {'chainId': int(testerchain.client.chain_id),
'nonce': testerchain.client.w3.eth.get_transaction_count(testerchain.etherbase_account),
'from': testerchain.etherbase_account,
'gasPrice': testerchain.client.gas_price,
'gas': 500_000}
payload = {
"chainId": int(testerchain.client.chain_id),
"nonce": testerchain.client.w3.eth.get_transaction_count(
testerchain.etherbase_account
),
"from": testerchain.etherbase_account,
"gasPrice": testerchain.client.gas_price,
"gas": 500_000,
}
unsigned_transaction = contract_function.build_transaction(payload)
# Sign with Transacting Power
transacting_power = TransactingPower(password=INSECURE_DEVELOPMENT_PASSWORD,
signer=Web3Signer(testerchain.client),
account=testerchain.etherbase_account)
signed_raw_transaction = transacting_power.sign_transaction(unsigned_transaction)
transacting_power = TransactingPower(
password=INSECURE_DEVELOPMENT_PASSWORD,
signer=Web3Signer(testerchain.client),
account=testerchain.etherbase_account,
)
signed_raw_transaction = transacting_power.sign_transaction(
unsigned_transaction
).rawTransaction
# Demonstrate that the transaction is valid RLP encoded.
restored_transaction = Transaction.from_bytes(serialized_bytes=signed_raw_transaction)
restored_transaction = Transaction.from_bytes(
serialized_bytes=signed_raw_transaction
)
restored_dict = restored_transaction.as_dict()
assert to_checksum_address(restored_dict['to']) == unsigned_transaction['to']
assert to_checksum_address(restored_dict["to"]) == unsigned_transaction["to"]

View File

@ -301,10 +301,11 @@ def test_registry(deployed_contracts, module_mocker):
@pytest.mark.usefixtures("test_registry")
@pytest.fixture(scope="module")
def testerchain(project) -> TesterBlockchain:
def testerchain(project, clock) -> TesterBlockchain:
# Extract the web3 provider containing EthereumTester from the ape project's chain manager
provider = project.chain_manager.provider.web3.provider
testerchain = TesterBlockchain(provider=provider)
testerchain.tx_machine._task.clock = clock
BlockchainInterfaceFactory.register_interface(interface=testerchain, force=True)
yield testerchain

View File

@ -3,12 +3,14 @@ import json
import os
import shutil
import tempfile
import time
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import Tuple
from unittest.mock import PropertyMock
import atxm
import maya
import pytest
from click.testing import CliRunner
@ -22,7 +24,6 @@ import tests
from nucypher.blockchain.eth.actors import Operator
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from nucypher.blockchain.eth.signers.software import KeystoreSigner
from nucypher.blockchain.eth.trackers.dkg import EventScannerTask
from nucypher.characters.lawful import Enrico, Ursula
from nucypher.config.characters import (
AliceConfiguration,
@ -45,6 +46,7 @@ from nucypher.policy.payment import SubscriptionManagerPayment
from nucypher.utilities.emitters import StdoutEmitter
from nucypher.utilities.logging import GlobalLoggerSettings, Logger
from nucypher.utilities.networking import LOOPBACK_ADDRESS
from nucypher.utilities.task import SimpleTask
from tests.constants import (
MIN_OPERATOR_SECONDS,
MOCK_CUSTOM_INSTALLATION_PATH,
@ -98,10 +100,11 @@ def tempfile_path():
@pytest.fixture(scope="module")
def temp_dir_path():
temp_dir = tempfile.TemporaryDirectory(prefix='nucypher-test-')
temp_dir = tempfile.TemporaryDirectory(prefix="nucypher-test-")
yield Path(temp_dir.name)
temp_dir.cleanup()
#
# Accounts
#
@ -118,6 +121,7 @@ def random_account():
def random_address(random_account):
return random_account.address
#
# Character Configurations
#
@ -224,7 +228,9 @@ def capsule_side_channel(enacted_policy):
self.plaintext_passthrough = False
def __call__(self):
message = "Welcome to flippering number {}.".format(len(self.messages)).encode()
message = "Welcome to flippering number {}.".format(
len(self.messages)
).encode()
message_kit = self.enrico.encrypt_for_pre(message)
self.messages.append((message_kit, self.enrico))
if self.plaintext_passthrough:
@ -250,6 +256,7 @@ def random_policy_label():
# Alice, Bob, and Ursula
#
@pytest.fixture(scope="module")
def alice(alice_test_config, ursulas, testerchain):
alice = alice_test_config.produce()
@ -290,6 +297,7 @@ def lonely_ursula_maker(ursula_test_config, testerchain):
del MOCK_KNOWN_URSULAS_CACHE[ursula.rest_interface.port]
for ursula in self._made:
ursula._finalize()
_maker = _PartialUrsulaMaker()
yield _maker
_maker.clean()
@ -304,7 +312,7 @@ def mock_registry_sources(module_mocker):
yield
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def mock_testerchain() -> MockBlockchain:
BlockchainInterfaceFactory._interfaces = dict()
testerchain = MockBlockchain()
@ -314,9 +322,7 @@ def mock_testerchain() -> MockBlockchain:
@pytest.fixture()
def light_ursula(temp_dir_path, random_account, mocker):
mocker.patch.object(
KeystoreSigner, "_KeystoreSigner__get_signer", return_value=random_account
)
mocker.patch.object(KeystoreSigner, "_get_signer", return_value=random_account)
pre_payment_method = SubscriptionManagerPayment(
blockchain_endpoint=MOCK_ETH_PROVIDER_URI, domain=TEMPORARY_DOMAIN_NAME
)
@ -340,13 +346,13 @@ def light_ursula(temp_dir_path, random_account, mocker):
return ursula
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def policy_rate():
rate = Web3.to_wei(21, 'gwei')
rate = Web3.to_wei(21, "gwei")
return rate
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def policy_value(policy_rate):
value = policy_rate * MIN_OPERATOR_SECONDS
return value
@ -357,7 +363,7 @@ def policy_value(policy_rate):
#
@pytest.fixture(autouse=True, scope='function')
@pytest.fixture(autouse=True, scope="function")
def log_in_and_out_of_test(request):
test_name = request.node.name
module_name = request.module.__name__
@ -385,19 +391,22 @@ def fleet_of_highperf_mocked_ursulas(ursula_test_config, request, testerchain):
mock_cert_generation,
mock_remember_node,
mock_message_verification,
)
)
try:
quantity = request.param
except AttributeError:
quantity = 5000 # Bigass fleet by default; that's kinda the point.
staking_addresses = (to_checksum_address('0x' + os.urandom(20).hex()) for _ in range(5000))
operator_addresses = (to_checksum_address('0x' + os.urandom(20).hex()) for _ in range(5000))
staking_addresses = (
to_checksum_address("0x" + os.urandom(20).hex()) for _ in range(5000)
)
operator_addresses = (
to_checksum_address("0x" + os.urandom(20).hex()) for _ in range(5000)
)
with GlobalLoggerSettings.pause_all_logging_while():
with contextlib.ExitStack() as stack:
for mock in mocks:
stack.enter_context(mock)
@ -415,7 +424,9 @@ def fleet_of_highperf_mocked_ursulas(ursula_test_config, request, testerchain):
# It only needs to see whatever public info we can normally get via REST.
# Also sharing mutable Ursulas like that can lead to unpredictable results.
ursula.known_nodes.current_state._nodes = all_ursulas
ursula.known_nodes.current_state.checksum = b"This is a fleet state checksum..".hex()
ursula.known_nodes.current_state.checksum = (
b"This is a fleet state checksum..".hex()
)
yield _ursulas
@ -474,7 +485,8 @@ def highperf_mocked_bob(fleet_of_highperf_mocked_ursulas):
# CLI
#
@pytest.fixture(scope='function')
@pytest.fixture(scope="function")
def test_emitter(mocker):
# Note that this fixture does not capture console output.
# Whether the output is captured or not is controlled by
@ -482,13 +494,13 @@ def test_emitter(mocker):
return StdoutEmitter()
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def click_runner():
runner = CliRunner()
yield runner
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def nominal_configuration_fields():
config = UrsulaConfiguration(
dev_mode=True,
@ -500,7 +512,7 @@ def nominal_configuration_fields():
del config
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def custom_filepath():
_custom_filepath = MOCK_CUSTOM_INSTALLATION_PATH
with contextlib.suppress(FileNotFoundError):
@ -510,7 +522,7 @@ def custom_filepath():
shutil.rmtree(_custom_filepath, ignore_errors=True)
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def custom_filepath_2():
_custom_filepath = MOCK_CUSTOM_INSTALLATION_PATH_2
with contextlib.suppress(FileNotFoundError):
@ -522,9 +534,11 @@ def custom_filepath_2():
shutil.rmtree(_custom_filepath, ignore_errors=True)
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def worker_configuration_file_location(custom_filepath) -> Path:
_configuration_file_location = MOCK_CUSTOM_INSTALLATION_PATH / UrsulaConfiguration.generate_filename()
_configuration_file_location = (
MOCK_CUSTOM_INSTALLATION_PATH / UrsulaConfiguration.generate_filename()
)
return _configuration_file_location
@ -537,15 +551,15 @@ def mock_teacher_nodes(mocker):
@pytest.fixture(autouse=True)
def disable_interactive_keystore_generation(mocker):
# Do not notify or confirm mnemonic seed words during tests normally
mocker.patch.object(Keystore, '_confirm_generate')
mocker.patch.object(Keystore, "_confirm_generate")
#
# Web Auth
#
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def basic_auth_file(temp_dir_path):
basic_auth = Path(temp_dir_path) / 'htpasswd'
basic_auth = Path(temp_dir_path) / "htpasswd"
with basic_auth.open("w") as f:
# username: "admin", password: "admin"
f.write("admin:$apr1$hlEpWVoI$0qjykXrvdZ0yO2TnBggQO0\n")
@ -553,7 +567,7 @@ def basic_auth_file(temp_dir_path):
basic_auth.unlink()
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def mock_rest_middleware():
return MockRestMiddleware(eth_endpoint=TEST_ETH_PROVIDER_URI)
@ -563,14 +577,14 @@ def mock_rest_middleware():
#
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def conditions_test_data():
test_conditions = Path(tests.__file__).parent / "data" / "test_conditions.json"
with open(test_conditions, 'r') as file:
with open(test_conditions, "r") as file:
data = json.loads(file.read())
for name, condition in data.items():
if condition.get('chain'):
condition['chain'] = TESTERCHAIN_CHAIN_ID
if condition.get("chain"):
condition["chain"] = TESTERCHAIN_CHAIN_ID
return data
@ -627,7 +641,7 @@ def rpc_condition():
return condition
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def valid_user_address_context():
return {
USER_ADDRESS_CONTEXT: {
@ -666,12 +680,12 @@ def valid_user_address_context():
}
@pytest.fixture(scope='module', autouse=True)
def control_time():
@pytest.fixture(scope="session", autouse=True)
def clock():
"""Distorts the space-time continuum. Use with caution."""
clock = Clock()
EventScannerTask.CLOCK = clock
EventScannerTask.INTERVAL = .1
clock.llamas = 0
SimpleTask.CLOCK = clock
SimpleTask.INTERVAL = 1
return clock
@ -689,7 +703,7 @@ def ursulas(testerchain, ursula_test_config, staking_providers):
know_each_other=True,
)
for u in _ursulas:
u.synchronous_query_timeout = .01 # We expect to never have to wait for content that is actually on-chain during tests.
u.synchronous_query_timeout = 0.01 # We expect to never have to wait for content that is actually on-chain during tests.
_ports_to_remove = [ursula.rest_interface.port for ursula in _ursulas]
yield _ursulas
@ -767,3 +781,14 @@ def mock_operator_aggregation_delay(module_mocker):
"nucypher.blockchain.eth.actors.Operator.AGGREGATION_SUBMISSION_MAX_DELAY",
PropertyMock(return_value=1),
)
@pytest.fixture(scope="session", autouse=True)
def mock_default_tracker_cache(session_mocker):
mock = session_mocker.patch.object(
atxm.state._State,
"_FILEPATH",
new_callable=session_mocker.PropertyMock,
)
mock.return_value = Path(tempfile.gettempdir()) / f".test-txs-{time.time()}.json"
return mock

View File

@ -112,13 +112,15 @@ def mock_stdin(mocker):
@pytest.fixture(scope="module")
def testerchain(mock_testerchain, module_mocker) -> MockBlockchain:
def testerchain(mock_testerchain, module_mocker, clock) -> MockBlockchain:
def always_use_mock(*a, **k):
return mock_testerchain
module_mocker.patch.object(
BlockchainInterfaceFactory, "get_interface", always_use_mock
)
mock_testerchain.tx_machine._task.clock = clock
return mock_testerchain

View File

@ -141,7 +141,7 @@ class MockCoordinatorAgent(MockContractAgent):
p.provider for p in ritual.participants
], # TODO This should not be
)
return self.blockchain.FAKE_TX_HASH
return self.blockchain.FAKE_ASYNX_TX
def post_aggregation(
self,
@ -172,10 +172,10 @@ class MockCoordinatorAgent(MockContractAgent):
ritual.aggregation_mismatch = True
# don't increment aggregations
# TODO Emit EndRitual here?
return self.blockchain.FAKE_TX_HASH
return self.blockchain.FAKE_ASYNX_TX
ritual.total_aggregations += 1
return self.blockchain.FAKE_TX_HASH
return self.blockchain.FAKE_ASYNX_TX
def set_provider_public_key(
self, public_key: FerveoPublicKey, transacting_power: TransactingPower

View File

@ -1,5 +1,6 @@
from typing import Union
from atxm.tx import FutureTx
from hexbytes import HexBytes
from nucypher.blockchain.eth.clients import EthereumTesterClient
@ -13,6 +14,19 @@ class MockBlockchain(TesterBlockchain):
FAKE_TX_HASH = HexBytes(b"FAKE29890FAKE8349804")
FAKE_ASYNX_TX = FutureTx(
id=1,
params={
"to": HexBytes(b"FAKEFAKEFAKEFAKEFAKEFAKEFAKEFAKEFAKEFAKE"),
"gas": 1,
"gasPrice": 1,
"value": 1,
"data": b"",
"nonce": 1,
},
_from=HexBytes(b"FAKEFAKEFAKEFAKEFAKEFAKEFAKEFAKEFAKE"),
)
FAKE_RECEIPT = {
"transactionHash": FAKE_TX_HASH,
"gasUsed": 1,

View File

@ -61,13 +61,15 @@ def mock_operator_bonding(session_mocker):
@pytest.fixture(scope="module")
def testerchain(mock_testerchain, module_mocker) -> MockBlockchain:
def testerchain(mock_testerchain, module_mocker, clock) -> MockBlockchain:
def always_use_mock(*a, **k):
return mock_testerchain
module_mocker.patch.object(
BlockchainInterfaceFactory, "get_interface", always_use_mock
)
mock_testerchain.tx_machine._task.clock = clock
return mock_testerchain

View File

@ -20,9 +20,6 @@ def test_operator_never_bonded(mocker, get_random_checksum_address):
tracker = OperatorBondedTracker(ursula=ursula)
try:
d = threads.deferToThread(tracker.start)
yield d
with pytest.raises(OperatorBondedTracker.OperatorNoLongerBonded):
d = threads.deferToThread(tracker.run)
yield d
@ -46,10 +43,6 @@ def test_operator_bonded_but_becomes_unbonded(mocker, get_random_checksum_addres
tracker = OperatorBondedTracker(ursula=ursula)
try:
d = threads.deferToThread(tracker.start)
yield d
# bonded
for i in range(1, 10):
d = threads.deferToThread(tracker.run)
yield d

View File

@ -13,9 +13,6 @@ def test_execution_of_collectors(mocker):
tracker = PrometheusMetricsTracker(collectors=collectors, interval=45)
try:
d = threads.deferToThread(tracker.start)
yield d
d = threads.deferToThread(tracker.run)
yield d

View File

@ -1,7 +1,4 @@
from unittest.mock import patch
import pytest
from hexbytes import HexBytes
from nucypher.blockchain.eth.agents import CoordinatorAgent
from nucypher.blockchain.eth.models import Coordinator
@ -9,6 +6,7 @@ from nucypher.blockchain.eth.signers.software import Web3Signer
from nucypher.crypto.powers import RitualisticPower, TransactingPower
from tests.constants import MOCK_ETH_PROVIDER_URI
from tests.mock.coordinator import MockCoordinatorAgent
from tests.mock.interfaces import MockBlockchain
@pytest.fixture(scope="module")
@ -22,8 +20,8 @@ def agent(mock_contract_agency, ursulas) -> MockCoordinatorAgent:
if ursula.checksum_address == provider:
return ursula.public_keys(RitualisticPower)
coordinator_agent.post_transcript = lambda *args, **kwargs: HexBytes("deadbeef1")
coordinator_agent.post_aggregation = lambda *args, **kwargs: HexBytes("deadbeef2")
coordinator_agent.post_transcript = lambda *a, **kw: MockBlockchain.FAKE_ASYNX_TX
coordinator_agent.post_aggregation = lambda *a, **kw: MockBlockchain.FAKE_ASYNX_TX
coordinator_agent.get_provider_public_key = mock_get_provider_public_key
return coordinator_agent
@ -132,86 +130,33 @@ def test_perform_round_1(
]
for state in non_application_states:
agent.get_ritual_status = lambda *args, **kwargs: state
tx_hash = ursula.perform_round_1(
tx = ursula.perform_round_1(
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
)
assert tx_hash is None # no execution performed
assert tx is None # no execution performed
# set correct state
agent.get_ritual_status = (
lambda *args, **kwargs: Coordinator.RitualStatus.DKG_AWAITING_TRANSCRIPTS
)
original_tx_hash = ursula.perform_round_1(
ursula.perform_round_1(
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
)
assert original_tx_hash is not None
# ensure tx hash is stored
assert ursula.dkg_storage.get_transcript_txhash(ritual_id=0) == original_tx_hash
# ensure tx is tracked
assert len(ursula.ritual_tracker.active_rituals) == 1
pid01 = ursula._phase_id(0, 1)
assert ursula.ritual_tracker.active_rituals[pid01]
# try again
tx_hash = ursula.perform_round_1(
ursula.perform_round_1(
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
)
assert tx_hash is None # no execution since pending tx already present
# pending tx gets mined and removed from storage - receipt status is 1
mock_receipt = {"status": 1}
with patch.object(
agent.blockchain.client, "get_transaction_receipt", return_value=mock_receipt
):
tx_hash = ursula.perform_round_1(
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
)
# no execution since pending tx was present and determined to be mined
assert tx_hash is None
# tx hash removed since tx receipt was obtained - outcome moving
# forward is represented on contract
assert ursula.dkg_storage.get_transcript_txhash(ritual_id=0) is None
# reset tx hash
ursula.dkg_storage.store_transcript_txhash(ritual_id=0, txhash=original_tx_hash)
# pending tx gets mined and removed from storage - receipt
# status is 0 i.e. evm revert - so use contract state which indicates
# to submit transcript
mock_receipt = {"status": 0}
with patch.object(
agent.blockchain.client, "get_transaction_receipt", return_value=mock_receipt
):
with patch.object(
agent, "post_transcript", lambda *args, **kwargs: HexBytes("A1B1")
):
mock_tx_hash = ursula.perform_round_1(
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
)
# execution occurs because evm revert causes execution to be retried
assert mock_tx_hash == HexBytes("A1B1")
# tx hash changed since original tx hash removed due to status being 0
# and new tx hash added
# forward is represented on contract
assert ursula.dkg_storage.get_transcript_txhash(ritual_id=0) == mock_tx_hash
assert (
ursula.dkg_storage.get_transcript_txhash(ritual_id=0)
!= original_tx_hash
)
# reset tx hash
ursula.dkg_storage.store_transcript_txhash(ritual_id=0, txhash=original_tx_hash)
# don't clear if tx hash mismatched
assert ursula.dkg_storage.get_transcript_txhash(ritual_id=0) is not None
assert not ursula.dkg_storage.clear_transcript_txhash(
ritual_id=0, txhash=HexBytes("abcd")
)
assert ursula.dkg_storage.get_transcript_txhash(ritual_id=0) is not None
# clear tx hash
assert ursula.dkg_storage.clear_transcript_txhash(
ritual_id=0, txhash=original_tx_hash
)
assert ursula.dkg_storage.get_transcript_txhash(ritual_id=0) is None
assert len(ursula.ritual_tracker.active_rituals) == 1
assert ursula.ritual_tracker.active_rituals[pid01]
# participant already posted transcript
participant = agent.get_participant(
@ -220,18 +165,21 @@ def test_perform_round_1(
participant.transcript = bytes(random_transcript)
# try submitting again
tx_hash = ursula.perform_round_1(
ursula.perform_round_1(
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
)
# no execution performed since already posted transcript
assert tx_hash is None
assert len(ursula.ritual_tracker.active_rituals) == 1
assert ursula.ritual_tracker.active_rituals[pid01]
# participant no longer already posted aggregated transcript
participant.transcript = bytes()
tx_hash = ursula.perform_round_1(
ursula.perform_round_1(
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
)
assert tx_hash is not None # execution occurs
assert len(ursula.ritual_tracker.active_rituals) == 1
assert ursula.ritual_tracker.active_rituals[pid01]
def test_perform_round_2(
@ -286,8 +234,11 @@ def test_perform_round_2(
]
for state in non_application_states:
agent.get_ritual_status = lambda *args, **kwargs: state
tx_hash = ursula.perform_round_2(ritual_id=0, timestamp=0)
assert tx_hash is None # no execution performed
ursula.perform_round_2(ritual_id=0, timestamp=0)
pid02 = ursula._phase_id(0, 2)
assert ursula.ritual_tracker.active_rituals[pid02]
assert len(ursula.ritual_tracker.active_rituals) == 1
# set correct state
agent.get_ritual_status = (
@ -295,81 +246,20 @@ def test_perform_round_2(
)
mocker.patch("nucypher.crypto.ferveo.dkg.verify_aggregate")
original_tx_hash = ursula.perform_round_2(ritual_id=0, timestamp=0)
assert original_tx_hash is not None
ursula.perform_round_2(ritual_id=0, timestamp=0)
# check tx hash
assert ursula.dkg_storage.get_aggregation_txhash(ritual_id=0) == original_tx_hash
# check tx hash tracking
assert len(ursula.ritual_tracker.active_rituals) == 2
# try again
tx_hash = ursula.perform_round_2(ritual_id=0, timestamp=0)
assert tx_hash is None # no execution since pending tx already present
ursula.perform_round_2(ritual_id=0, timestamp=0)
# pending tx gets mined and removed from storage - receipt status is 1
mock_receipt = {"status": 1}
with patch.object(
agent.blockchain.client, "get_transaction_receipt", return_value=mock_receipt
):
tx_hash = ursula.perform_round_2(ritual_id=0, timestamp=0)
# no execution since pending tx was present and determined to be mined
assert tx_hash is None
# tx hash removed since tx receipt was obtained - outcome moving
# forward is represented on contract
assert ursula.dkg_storage.get_aggregation_txhash(ritual_id=0) is None
# reset tx hash
ursula.dkg_storage.store_aggregation_txhash(ritual_id=0, txhash=original_tx_hash)
# pending tx gets mined and removed from storage - receipt
# status is 0 i.e. evm revert - so use contract state which indicates
# to submit transcript
mock_receipt = {"status": 0}
with patch.object(
agent.blockchain.client, "get_transaction_receipt", return_value=mock_receipt
):
with patch.object(
agent, "post_aggregation", lambda *args, **kwargs: HexBytes("A1B1")
):
mock_tx_hash = ursula.perform_round_2(ritual_id=0, timestamp=0)
# execution occurs because evm revert causes execution to be retried
assert mock_tx_hash == HexBytes("A1B1")
# tx hash changed since original tx hash removed due to status being 0
# and new tx hash added
# forward is represented on contract
assert (
ursula.dkg_storage.get_aggregation_txhash(ritual_id=0) == mock_tx_hash
)
assert (
ursula.dkg_storage.get_aggregation_txhash(ritual_id=0)
!= original_tx_hash
)
# reset tx hash
ursula.dkg_storage.store_aggregation_txhash(ritual_id=0, txhash=original_tx_hash)
# don't clear if tx hash mismatched
assert not ursula.dkg_storage.clear_aggregated_txhash(
ritual_id=0, txhash=HexBytes("1234")
)
assert ursula.dkg_storage.get_aggregation_txhash(ritual_id=0) is not None
# clear tx hash
assert ursula.dkg_storage.clear_aggregated_txhash(
ritual_id=0, txhash=original_tx_hash
)
assert ursula.dkg_storage.get_aggregation_txhash(ritual_id=0) is None
# participant already posted aggregated transcript
participant = agent.get_participant(
ritual_id=0, provider=ursula.checksum_address, transcript=False
)
participant.aggregated = True
# try submitting again
tx_hash = ursula.perform_round_2(ritual_id=0, timestamp=0)
assert tx_hash is None # no execution performed
ursula.perform_round_2(ritual_id=0, timestamp=0)
# participant no longer already posted aggregated transcript
participant.aggregated = False
tx_hash = ursula.perform_round_2(ritual_id=0, timestamp=0)
assert tx_hash is not None # execution occurs
ursula.perform_round_2(ritual_id=0, timestamp=0)

View File

@ -90,6 +90,7 @@ class TesterBlockchain(BlockchainInterface):
*args,
**kwargs,
)
self.log = Logger("test-blockchain")
self.connect()