From c16f93707ac6425b77f2540ada559275a68f7139 Mon Sep 17 00:00:00 2001 From: KPrasch Date: Thu, 15 Feb 2024 14:44:12 +0100 Subject: [PATCH] promotes phase id to a NamedTuple. --- nucypher/blockchain/eth/actors.py | 18 ++++++------------ nucypher/blockchain/eth/models.py | 6 +++--- nucypher/blockchain/eth/trackers/dkg.py | 6 +++--- nucypher/types.py | 9 +++++++-- tests/unit/test_ritualist.py | 7 ++++--- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/nucypher/blockchain/eth/actors.py b/nucypher/blockchain/eth/actors.py index aa27775fd..0a896d50e 100644 --- a/nucypher/blockchain/eth/actors.py +++ b/nucypher/blockchain/eth/actors.py @@ -3,7 +3,7 @@ import random import time from collections import defaultdict from decimal import Decimal -from typing import DefaultDict, Dict, List, Optional, Set, Tuple, Union +from typing import DefaultDict, Dict, List, Optional, Set, Union import maya from atxm.tx import AsyncTx @@ -56,7 +56,7 @@ from nucypher.datastore.dkg import DKGStorage from nucypher.policy.conditions.evm import _CONDITION_CHAINS from nucypher.policy.conditions.utils import evaluate_condition_lingo from nucypher.policy.payment import ContractPayment -from nucypher.types import PhaseId, RitualId +from nucypher.types import PhaseId from nucypher.utilities.emitters import StdoutEmitter from nucypher.utilities.logging import Logger @@ -352,7 +352,7 @@ class Operator(BaseActor): transcript=transcript, transacting_power=self.transacting_power, ) - identifier = self._phase_id(ritual_id, PHASE1) + identifier = PhaseId(ritual_id, PHASE1) self.ritual_tracker.active_rituals[identifier] = async_tx return async_tx @@ -374,7 +374,7 @@ class Operator(BaseActor): participant_public_key=participant_public_key, transacting_power=self.transacting_power, ) - identifier = self._phase_id(ritual_id=ritual_id, phase=PHASE2) + identifier = PhaseId(ritual_id=ritual_id, phase=PHASE2) self.ritual_tracker.active_rituals[identifier] = async_tx return async_tx @@ -407,10 +407,6 @@ class Operator(BaseActor): return True - @staticmethod - def _phase_id(ritual_id: int, phase: int) -> Tuple[RitualId, PhaseId]: - return RitualId(ritual_id), PhaseId(phase) - def perform_round_1( self, ritual_id: int, @@ -461,9 +457,7 @@ class Operator(BaseActor): return # check if there is already pending tx for this ritual + round combination - async_tx = self.ritual_tracker.active_rituals.get( - self._phase_id(ritual_id, PHASE1) - ) + async_tx = self.ritual_tracker.active_rituals.get(PhaseId(ritual_id, PHASE1)) if async_tx: self.log.info( f"Active ritual in progress: {self.transacting_power.account} has submitted tx " @@ -553,7 +547,7 @@ class Operator(BaseActor): # check if there is a pending tx for this ritual + round combination async_tx = self.ritual_tracker.active_rituals.get( - self._phase_id(ritual_id=ritual_id, phase=PHASE2) + PhaseId(ritual_id=ritual_id, phase=PHASE2) ) if async_tx: self.log.info( diff --git a/nucypher/blockchain/eth/models.py b/nucypher/blockchain/eth/models.py index c933dd32f..a3d41d2af 100644 --- a/nucypher/blockchain/eth/models.py +++ b/nucypher/blockchain/eth/models.py @@ -13,10 +13,10 @@ from nucypher_core.ferveo import ( FerveoPublicKey, ) -from nucypher.types import PhaseId +from nucypher.types import PhaseNumber -PHASE1 = PhaseId(1) -PHASE2 = PhaseId(2) +PHASE1 = PhaseNumber(1) +PHASE2 = PhaseNumber(2) @dataclass diff --git a/nucypher/blockchain/eth/trackers/dkg.py b/nucypher/blockchain/eth/trackers/dkg.py index c48a0b2a4..6fe1e1e17 100644 --- a/nucypher/blockchain/eth/trackers/dkg.py +++ b/nucypher/blockchain/eth/trackers/dkg.py @@ -11,7 +11,7 @@ from web3.datastructures import AttributeDict from nucypher.blockchain.eth.models import Coordinator from nucypher.policy.conditions.utils import camel_case_to_snake -from nucypher.types import PhaseId, RitualId +from nucypher.types import PhaseId from nucypher.utilities.cache import TTLCache from nucypher.utilities.events import EventScanner, JSONifiedState from nucypher.utilities.logging import Logger @@ -123,7 +123,7 @@ class ActiveRitualTracker: self.contract.events.EndRitual, ] - self.__phase_txs: Dict[Tuple[RitualId, PhaseId], FutureTx] = {} + self.__phase_txs: Dict[PhaseId, FutureTx] = {} # TODO: Remove the default JSON-RPC retry middleware # as it correctly cannot handle eth_getLogs block range throttle down. @@ -168,7 +168,7 @@ class ActiveRitualTracker: return self.coordinator_agent.contract @property - def active_rituals(self) -> Dict[Tuple[RitualId, PhaseId], AsyncTx]: + def active_rituals(self) -> Dict[PhaseId, AsyncTx]: return self.__phase_txs # TODO: should sample_window_size be additionally configurable/chain-dependent? diff --git a/nucypher/types.py b/nucypher/types.py index d2fbcf0d6..f6a6ff825 100644 --- a/nucypher/types.py +++ b/nucypher/types.py @@ -1,4 +1,4 @@ -from typing import NewType, TypeVar +from typing import NamedTuple, NewType, TypeVar ERC20UNits = NewType("ERC20UNits", int) NuNits = NewType("NuNits", ERC20UNits) @@ -7,4 +7,9 @@ TuNits = NewType("TuNits", ERC20UNits) Agent = TypeVar("Agent", bound="agents.EthereumContractAgent") # noqa: F821 RitualId = int -PhaseId = int +PhaseNumber = int + + +class PhaseId(NamedTuple): + ritual_id: RitualId + phase: PhaseNumber diff --git a/tests/unit/test_ritualist.py b/tests/unit/test_ritualist.py index 4918d94ce..bafe5356e 100644 --- a/tests/unit/test_ritualist.py +++ b/tests/unit/test_ritualist.py @@ -4,6 +4,7 @@ from nucypher.blockchain.eth.agents import CoordinatorAgent from nucypher.blockchain.eth.models import PHASE1, PHASE2, Coordinator from nucypher.blockchain.eth.signers.software import Web3Signer from nucypher.crypto.powers import RitualisticPower, TransactingPower +from nucypher.types import PhaseId from tests.constants import MOCK_ETH_PROVIDER_URI from tests.mock.coordinator import MockCoordinatorAgent from tests.mock.interfaces import MockBlockchain @@ -148,7 +149,7 @@ def test_perform_round_1( assert async_tx assert len(ursula.ritual_tracker.active_rituals) == 1 - pid01 = ursula._phase_id(0, 1) + pid01 = PhaseId(ritual_id=0, phase=PHASE1) assert ursula.ritual_tracker.active_rituals[pid01] # try again @@ -241,7 +242,7 @@ def test_perform_round_2( ursula.perform_round_2(ritual_id=0, timestamp=0) assert len(ursula.ritual_tracker.active_rituals) == 1 - pid01 = ursula._phase_id(ritual_id=0, phase=PHASE1) + pid01 = PhaseId(ritual_id=0, phase=PHASE1) assert ursula.ritual_tracker.active_rituals[pid01] # set correct state @@ -254,7 +255,7 @@ def test_perform_round_2( # check async tx tracking assert len(ursula.ritual_tracker.active_rituals) == 2 - pid02 = ursula._phase_id(ritual_id=0, phase=PHASE2) + pid02 = PhaseId(ritual_id=0, phase=PHASE2) assert ursula.ritual_tracker.active_rituals[pid02] # trying again yields same tx