mirror of https://github.com/nucypher/nucypher.git
Don't store async_txs on the ActiveRitualTracker since it only cares about events; utilize existing DKG storage for operator. This will also allow for easier logic for clearing when no longer needed by the Operator.
Modify DKGStorage to store async txs. Update tests accordingly.pull/3476/head
parent
2be52480c9
commit
dcb93abed8
|
@ -419,7 +419,9 @@ class Operator(BaseActor):
|
|||
transacting_power=self.transacting_power,
|
||||
async_tx_hooks=async_tx_hooks,
|
||||
)
|
||||
self.ritual_tracker.active_rituals[identifier] = async_tx
|
||||
self.dkg_storage.store_ritual_phase_async_tx(
|
||||
phase_id=identifier, async_tx=async_tx
|
||||
)
|
||||
return async_tx
|
||||
|
||||
def publish_aggregated_transcript(
|
||||
|
@ -445,8 +447,9 @@ class Operator(BaseActor):
|
|||
transacting_power=self.transacting_power,
|
||||
async_tx_hooks=async_tx_hooks,
|
||||
)
|
||||
|
||||
self.ritual_tracker.active_rituals[identifier] = async_tx
|
||||
self.dkg_storage.store_ritual_phase_async_tx(
|
||||
phase_id=identifier, async_tx=async_tx
|
||||
)
|
||||
return async_tx
|
||||
|
||||
def _is_phase_1_action_required(self, ritual_id: int) -> bool:
|
||||
|
@ -521,8 +524,10 @@ class Operator(BaseActor):
|
|||
)
|
||||
return
|
||||
|
||||
# check if there is already pending tx for this ritual + round combination
|
||||
async_tx = self.ritual_tracker.active_rituals.get(PhaseId(ritual_id, PHASE1))
|
||||
# check if there is already pending tx for this ritual + round combination
|
||||
async_tx = self.dkg_storage.get_ritual_phase_async_tx(
|
||||
phase_id=PhaseId(ritual_id, PHASE1)
|
||||
)
|
||||
if async_tx:
|
||||
self.log.info(
|
||||
f"Active ritual in progress: {self.transacting_power.account} has submitted tx "
|
||||
|
@ -611,12 +616,12 @@ class Operator(BaseActor):
|
|||
return
|
||||
|
||||
# check if there is a pending tx for this ritual + round combination
|
||||
async_tx = self.ritual_tracker.active_rituals.get(
|
||||
PhaseId(ritual_id=ritual_id, phase=PHASE2)
|
||||
async_tx = self.dkg_storage.get_ritual_phase_async_tx(
|
||||
phase_id=PhaseId(ritual_id, PHASE2)
|
||||
)
|
||||
if async_tx:
|
||||
self.log.info(
|
||||
f"Active ritual in progress Node {self.transacting_power.account} has submitted tx"
|
||||
f"Active ritual in progress: {self.transacting_power.account} has submitted tx"
|
||||
f"for ritual #{ritual_id}, phase #{PHASE2} (final: {async_tx.final})."
|
||||
)
|
||||
return async_tx
|
||||
|
|
|
@ -1,17 +1,15 @@
|
|||
import datetime
|
||||
import os
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import maya
|
||||
from atxm.tx import AsyncTx, FutureTx
|
||||
from prometheus_client import REGISTRY, Gauge
|
||||
from twisted.internet import threads
|
||||
from web3.datastructures import AttributeDict
|
||||
|
||||
from nucypher.blockchain.eth.models import Coordinator
|
||||
from nucypher.policy.conditions.utils import camel_case_to_snake
|
||||
from nucypher.types import PhaseId
|
||||
from nucypher.utilities.cache import TTLCache
|
||||
from nucypher.utilities.events import EventScanner, JSONifiedState
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
@ -123,8 +121,6 @@ class ActiveRitualTracker:
|
|||
self.contract.events.EndRitual,
|
||||
]
|
||||
|
||||
self.__phase_txs: Dict[PhaseId, FutureTx] = {}
|
||||
|
||||
# TODO: Remove the default JSON-RPC retry middleware
|
||||
# as it correctly cannot handle eth_getLogs block range throttle down.
|
||||
# self.web3.middleware_onion.remove(http_retry_request_middleware)
|
||||
|
@ -167,10 +163,6 @@ class ActiveRitualTracker:
|
|||
def contract(self):
|
||||
return self.coordinator_agent.contract
|
||||
|
||||
@property
|
||||
def active_rituals(self) -> Dict[PhaseId, AsyncTx]:
|
||||
return self.__phase_txs
|
||||
|
||||
# TODO: should sample_window_size be additionally configurable/chain-dependent?
|
||||
def _get_first_scan_start_block_number(self, sample_window_size: int = 100) -> int:
|
||||
"""
|
||||
|
|
|
@ -1,30 +1,31 @@
|
|||
from collections import defaultdict
|
||||
from typing import List, Optional
|
||||
|
||||
from hexbytes import HexBytes
|
||||
from atxm.tx import AsyncTx
|
||||
from nucypher_core.ferveo import (
|
||||
Validator,
|
||||
)
|
||||
|
||||
from nucypher.blockchain.eth.models import Coordinator
|
||||
from nucypher.blockchain.eth.models import PHASE1, Coordinator
|
||||
from nucypher.types import PhaseId
|
||||
|
||||
|
||||
class DKGStorage:
|
||||
"""A simple in-memory storage for DKG data"""
|
||||
|
||||
# round 1
|
||||
KEY_TRANSCRIPT_TXS = "transcript_tx_hashes"
|
||||
KEY_VALIDATORS = "validators"
|
||||
_KEY_PHASE_1_TXS = "phase_1_txs"
|
||||
_KEY_VALIDATORS = "validators"
|
||||
# round 2
|
||||
KEY_AGGREGATED_TXS = "aggregation_tx_hashes"
|
||||
_KEY_PHASE_2_TXS = "phase_2_txs"
|
||||
# active rituals
|
||||
KEY_ACTIVE_RITUAL = "active_rituals"
|
||||
_KEY_ACTIVE_RITUAL = "active_rituals"
|
||||
|
||||
_KEYS = [
|
||||
KEY_TRANSCRIPT_TXS,
|
||||
KEY_VALIDATORS,
|
||||
KEY_AGGREGATED_TXS,
|
||||
KEY_ACTIVE_RITUAL,
|
||||
_KEY_PHASE_1_TXS,
|
||||
_KEY_VALIDATORS,
|
||||
_KEY_PHASE_2_TXS,
|
||||
_KEY_ACTIVE_RITUAL,
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
|
@ -38,45 +39,39 @@ class DKGStorage:
|
|||
continue
|
||||
|
||||
#
|
||||
# DKG Round 1 - Transcripts
|
||||
# DKG Phases
|
||||
#
|
||||
def store_transcript_txhash(self, ritual_id: int, txhash: HexBytes) -> None:
|
||||
self.data[self.KEY_TRANSCRIPT_TXS][ritual_id] = txhash
|
||||
@classmethod
|
||||
def __get_phase_key(cls, phase: int):
|
||||
if phase == PHASE1:
|
||||
return cls._KEY_PHASE_1_TXS
|
||||
return cls._KEY_PHASE_2_TXS
|
||||
|
||||
def clear_transcript_txhash(self, ritual_id: int, txhash: HexBytes) -> bool:
|
||||
if self.get_transcript_txhash(ritual_id) == txhash:
|
||||
del self.data[self.KEY_TRANSCRIPT_TXS][ritual_id]
|
||||
def store_ritual_phase_async_tx(self, phase_id: PhaseId, async_tx: AsyncTx):
|
||||
key = self.__get_phase_key(phase_id.phase)
|
||||
self.data[key][phase_id.ritual_id] = async_tx
|
||||
|
||||
def clear_ritual_phase_async_tx(self, phase_id: PhaseId, async_tx: AsyncTx) -> bool:
|
||||
key = self.__get_phase_key(phase_id.phase)
|
||||
if self.data[key][phase_id.ritual_id] is async_tx:
|
||||
del self.data[key][phase_id.ritual_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_transcript_txhash(self, ritual_id: int) -> Optional[HexBytes]:
|
||||
return self.data[self.KEY_TRANSCRIPT_TXS].get(ritual_id)
|
||||
def get_ritual_phase_async_tx(self, phase_id: PhaseId) -> Optional[AsyncTx]:
|
||||
key = self.__get_phase_key(phase_id.phase)
|
||||
return self.data[key].get(phase_id.ritual_id)
|
||||
|
||||
def store_validators(self, ritual_id: int, validators: List[Validator]) -> None:
|
||||
self.data[self.KEY_VALIDATORS][ritual_id] = list(validators)
|
||||
self.data[self._KEY_VALIDATORS][ritual_id] = list(validators)
|
||||
|
||||
def get_validators(self, ritual_id: int) -> Optional[List[Validator]]:
|
||||
validators = self.data[self.KEY_VALIDATORS].get(ritual_id)
|
||||
validators = self.data[self._KEY_VALIDATORS].get(ritual_id)
|
||||
if not validators:
|
||||
return None
|
||||
|
||||
return list(validators)
|
||||
|
||||
#
|
||||
# DKG Round 2 - Aggregation
|
||||
#
|
||||
def store_aggregation_txhash(self, ritual_id: int, txhash: HexBytes) -> None:
|
||||
self.data[self.KEY_AGGREGATED_TXS][ritual_id] = txhash
|
||||
|
||||
def clear_aggregated_txhash(self, ritual_id: int, txhash: HexBytes) -> bool:
|
||||
if self.get_aggregation_txhash(ritual_id) == txhash:
|
||||
del self.data[self.KEY_AGGREGATED_TXS][ritual_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_aggregation_txhash(self, ritual_id: int) -> Optional[HexBytes]:
|
||||
return self.data[self.KEY_AGGREGATED_TXS].get(ritual_id)
|
||||
|
||||
#
|
||||
# Active Rituals
|
||||
#
|
||||
|
@ -84,7 +79,7 @@ class DKGStorage:
|
|||
if active_ritual.total_aggregations != active_ritual.dkg_size:
|
||||
# safeguard against a non-active ritual being cached
|
||||
raise ValueError("Only active rituals can be cached")
|
||||
self.data[self.KEY_ACTIVE_RITUAL][active_ritual.id] = active_ritual
|
||||
self.data[self._KEY_ACTIVE_RITUAL][active_ritual.id] = active_ritual
|
||||
|
||||
def get_active_ritual(self, ritual_id: int) -> Optional[Coordinator.Ritual]:
|
||||
return self.data[self.KEY_ACTIVE_RITUAL].get(ritual_id)
|
||||
return self.data[self._KEY_ACTIVE_RITUAL].get(ritual_id)
|
||||
|
|
|
@ -141,16 +141,19 @@ def test_perform_round_1(
|
|||
lambda *args, **kwargs: Coordinator.RitualStatus.DKG_AWAITING_TRANSCRIPTS
|
||||
)
|
||||
|
||||
phase_id = PhaseId(ritual_id=0, phase=PHASE1)
|
||||
|
||||
assert (
|
||||
ursula.dkg_storage.get_ritual_phase_async_tx(phase_id=phase_id) is None
|
||||
), "no tx data as yet"
|
||||
|
||||
async_tx = ursula.perform_round_1(
|
||||
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
|
||||
)
|
||||
|
||||
# ensure tx is tracked
|
||||
assert async_tx
|
||||
assert len(ursula.ritual_tracker.active_rituals) == 1
|
||||
|
||||
pid01 = PhaseId(ritual_id=0, phase=PHASE1)
|
||||
assert ursula.ritual_tracker.active_rituals[pid01]
|
||||
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_id=phase_id) is async_tx
|
||||
|
||||
# try again
|
||||
async_tx2 = ursula.perform_round_1(
|
||||
|
@ -158,23 +161,18 @@ def test_perform_round_1(
|
|||
)
|
||||
|
||||
assert async_tx2 is async_tx
|
||||
assert len(ursula.ritual_tracker.active_rituals) == 1
|
||||
assert ursula.ritual_tracker.active_rituals[pid01] is async_tx2
|
||||
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_id=phase_id) is async_tx2
|
||||
|
||||
# participant already posted transcript
|
||||
participant = agent.get_participant(
|
||||
ritual_id=0, provider=ursula.checksum_address, transcript=False
|
||||
)
|
||||
participant.transcript = bytes(random_transcript)
|
||||
|
||||
# try submitting again
|
||||
result = ursula.perform_round_1(
|
||||
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert len(ursula.ritual_tracker.active_rituals) == 1
|
||||
assert ursula.ritual_tracker.active_rituals[pid01]
|
||||
|
||||
# participant no longer already posted aggregated transcript
|
||||
participant.transcript = bytes()
|
||||
|
@ -183,8 +181,7 @@ def test_perform_round_1(
|
|||
)
|
||||
|
||||
assert async_tx3 is async_tx
|
||||
assert len(ursula.ritual_tracker.active_rituals) == 1
|
||||
assert ursula.ritual_tracker.active_rituals[pid01]
|
||||
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_id=phase_id) is async_tx3
|
||||
|
||||
|
||||
def test_perform_round_2(
|
||||
|
@ -241,27 +238,33 @@ def test_perform_round_2(
|
|||
agent.get_ritual_status = lambda *args, **kwargs: state
|
||||
ursula.perform_round_2(ritual_id=0, timestamp=0)
|
||||
|
||||
assert len(ursula.ritual_tracker.active_rituals) == 1
|
||||
pid01 = PhaseId(ritual_id=0, phase=PHASE1)
|
||||
assert ursula.ritual_tracker.active_rituals[pid01]
|
||||
phase_1_id = PhaseId(ritual_id=0, phase=PHASE1)
|
||||
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_1_id) is not None
|
||||
|
||||
# set correct state
|
||||
agent.get_ritual_status = (
|
||||
lambda *args, **kwargs: Coordinator.RitualStatus.DKG_AWAITING_AGGREGATIONS
|
||||
)
|
||||
|
||||
phase_2_id = PhaseId(ritual_id=0, phase=PHASE2)
|
||||
|
||||
assert (
|
||||
ursula.dkg_storage.get_ritual_phase_async_tx(phase_id=phase_2_id) is None
|
||||
), "no tx data as yet"
|
||||
|
||||
mocker.patch("nucypher.crypto.ferveo.dkg.verify_aggregate")
|
||||
async_tx = ursula.perform_round_2(ritual_id=0, timestamp=0)
|
||||
|
||||
# check async tx tracking
|
||||
assert len(ursula.ritual_tracker.active_rituals) == 2
|
||||
pid02 = PhaseId(ritual_id=0, phase=PHASE2)
|
||||
assert ursula.ritual_tracker.active_rituals[pid02]
|
||||
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_2_id) is async_tx
|
||||
assert (
|
||||
ursula.dkg_storage.get_ritual_phase_async_tx(phase_1_id) is not async_tx
|
||||
), "phase 1 separate from phase 2"
|
||||
|
||||
# trying again yields same tx
|
||||
async_tx2 = ursula.perform_round_2(ritual_id=0, timestamp=0)
|
||||
assert len(ursula.ritual_tracker.active_rituals) == 2
|
||||
assert async_tx2 is async_tx
|
||||
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_2_id) is async_tx2
|
||||
|
||||
# No action required
|
||||
participant = agent.get_participant(
|
||||
|
@ -275,3 +278,4 @@ def test_perform_round_2(
|
|||
participant.aggregated = False
|
||||
async_tx4 = ursula.perform_round_2(ritual_id=0, timestamp=0)
|
||||
assert async_tx4 is async_tx
|
||||
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_2_id) is async_tx4
|
||||
|
|
Loading…
Reference in New Issue