promotes phase id to a NamedTuple.

pull/3475/head
KPrasch 2024-02-15 14:44:12 +01:00 committed by derekpierre
parent bf2dbbc2c0
commit c16f93707a
No known key found for this signature in database
5 changed files with 23 additions and 23 deletions

View File

@ -3,7 +3,7 @@ import random
import time
from collections import defaultdict
from decimal import Decimal
from typing import DefaultDict, Dict, List, Optional, Set, Tuple, Union
from typing import DefaultDict, Dict, List, Optional, Set, Union
import maya
from atxm.tx import AsyncTx
@ -56,7 +56,7 @@ from nucypher.datastore.dkg import DKGStorage
from nucypher.policy.conditions.evm import _CONDITION_CHAINS
from nucypher.policy.conditions.utils import evaluate_condition_lingo
from nucypher.policy.payment import ContractPayment
from nucypher.types import PhaseId, RitualId
from nucypher.types import PhaseId
from nucypher.utilities.emitters import StdoutEmitter
from nucypher.utilities.logging import Logger
@ -352,7 +352,7 @@ class Operator(BaseActor):
transcript=transcript,
transacting_power=self.transacting_power,
)
identifier = self._phase_id(ritual_id, PHASE1)
identifier = PhaseId(ritual_id, PHASE1)
self.ritual_tracker.active_rituals[identifier] = async_tx
return async_tx
@ -374,7 +374,7 @@ class Operator(BaseActor):
participant_public_key=participant_public_key,
transacting_power=self.transacting_power,
)
identifier = self._phase_id(ritual_id=ritual_id, phase=PHASE2)
identifier = PhaseId(ritual_id=ritual_id, phase=PHASE2)
self.ritual_tracker.active_rituals[identifier] = async_tx
return async_tx
@ -407,10 +407,6 @@ class Operator(BaseActor):
return True
@staticmethod
def _phase_id(ritual_id: int, phase: int) -> Tuple[RitualId, PhaseId]:
return RitualId(ritual_id), PhaseId(phase)
def perform_round_1(
self,
ritual_id: int,
@ -461,9 +457,7 @@ class Operator(BaseActor):
return
# check if there is already pending tx for this ritual + round combination
async_tx = self.ritual_tracker.active_rituals.get(
self._phase_id(ritual_id, PHASE1)
)
async_tx = self.ritual_tracker.active_rituals.get(PhaseId(ritual_id, PHASE1))
if async_tx:
self.log.info(
f"Active ritual in progress: {self.transacting_power.account} has submitted tx "
@ -553,7 +547,7 @@ class Operator(BaseActor):
# check if there is a pending tx for this ritual + round combination
async_tx = self.ritual_tracker.active_rituals.get(
self._phase_id(ritual_id=ritual_id, phase=PHASE2)
PhaseId(ritual_id=ritual_id, phase=PHASE2)
)
if async_tx:
self.log.info(

View File

@ -13,10 +13,10 @@ from nucypher_core.ferveo import (
FerveoPublicKey,
)
from nucypher.types import PhaseId
from nucypher.types import PhaseNumber
PHASE1 = PhaseId(1)
PHASE2 = PhaseId(2)
PHASE1 = PhaseNumber(1)
PHASE2 = PhaseNumber(2)
@dataclass

View File

@ -11,7 +11,7 @@ from web3.datastructures import AttributeDict
from nucypher.blockchain.eth.models import Coordinator
from nucypher.policy.conditions.utils import camel_case_to_snake
from nucypher.types import PhaseId, RitualId
from nucypher.types import PhaseId
from nucypher.utilities.cache import TTLCache
from nucypher.utilities.events import EventScanner, JSONifiedState
from nucypher.utilities.logging import Logger
@ -123,7 +123,7 @@ class ActiveRitualTracker:
self.contract.events.EndRitual,
]
self.__phase_txs: Dict[Tuple[RitualId, PhaseId], FutureTx] = {}
self.__phase_txs: Dict[PhaseId, FutureTx] = {}
# TODO: Remove the default JSON-RPC retry middleware
# as it correctly cannot handle eth_getLogs block range throttle down.
@ -168,7 +168,7 @@ class ActiveRitualTracker:
return self.coordinator_agent.contract
@property
def active_rituals(self) -> Dict[Tuple[RitualId, PhaseId], AsyncTx]:
def active_rituals(self) -> Dict[PhaseId, AsyncTx]:
return self.__phase_txs
# TODO: should sample_window_size be additionally configurable/chain-dependent?

View File

@ -1,4 +1,4 @@
from typing import NewType, TypeVar
from typing import NamedTuple, NewType, TypeVar
ERC20UNits = NewType("ERC20UNits", int)
NuNits = NewType("NuNits", ERC20UNits)
@ -7,4 +7,9 @@ TuNits = NewType("TuNits", ERC20UNits)
Agent = TypeVar("Agent", bound="agents.EthereumContractAgent") # noqa: F821
RitualId = int
PhaseId = int
PhaseNumber = int
class PhaseId(NamedTuple):
ritual_id: RitualId
phase: PhaseNumber

View File

@ -4,6 +4,7 @@ from nucypher.blockchain.eth.agents import CoordinatorAgent
from nucypher.blockchain.eth.models import PHASE1, PHASE2, Coordinator
from nucypher.blockchain.eth.signers.software import Web3Signer
from nucypher.crypto.powers import RitualisticPower, TransactingPower
from nucypher.types import PhaseId
from tests.constants import MOCK_ETH_PROVIDER_URI
from tests.mock.coordinator import MockCoordinatorAgent
from tests.mock.interfaces import MockBlockchain
@ -148,7 +149,7 @@ def test_perform_round_1(
assert async_tx
assert len(ursula.ritual_tracker.active_rituals) == 1
pid01 = ursula._phase_id(0, 1)
pid01 = PhaseId(ritual_id=0, phase=PHASE1)
assert ursula.ritual_tracker.active_rituals[pid01]
# try again
@ -241,7 +242,7 @@ def test_perform_round_2(
ursula.perform_round_2(ritual_id=0, timestamp=0)
assert len(ursula.ritual_tracker.active_rituals) == 1
pid01 = ursula._phase_id(ritual_id=0, phase=PHASE1)
pid01 = PhaseId(ritual_id=0, phase=PHASE1)
assert ursula.ritual_tracker.active_rituals[pid01]
# set correct state
@ -254,7 +255,7 @@ def test_perform_round_2(
# check async tx tracking
assert len(ursula.ritual_tracker.active_rituals) == 2
pid02 = ursula._phase_id(ritual_id=0, phase=PHASE2)
pid02 = PhaseId(ritual_id=0, phase=PHASE2)
assert ursula.ritual_tracker.active_rituals[pid02]
# trying again yields same tx