Don't store async_txs on the ActiveRitualTracker, since it only cares about events; utilize the existing DKG storage for the Operator instead. This will also allow simpler logic for clearing the txs when they are no longer needed by the Operator.

Modify DKGStorage to store async txs.
Update tests accordingly.
pull/3476/head
derekpierre 2024-04-11 16:51:49 -04:00
parent 2be52480c9
commit dcb93abed8
No known key found for this signature in database
4 changed files with 69 additions and 73 deletions

View File

@ -419,7 +419,9 @@ class Operator(BaseActor):
transacting_power=self.transacting_power,
async_tx_hooks=async_tx_hooks,
)
self.ritual_tracker.active_rituals[identifier] = async_tx
self.dkg_storage.store_ritual_phase_async_tx(
phase_id=identifier, async_tx=async_tx
)
return async_tx
def publish_aggregated_transcript(
@ -445,8 +447,9 @@ class Operator(BaseActor):
transacting_power=self.transacting_power,
async_tx_hooks=async_tx_hooks,
)
self.ritual_tracker.active_rituals[identifier] = async_tx
self.dkg_storage.store_ritual_phase_async_tx(
phase_id=identifier, async_tx=async_tx
)
return async_tx
def _is_phase_1_action_required(self, ritual_id: int) -> bool:
@ -521,8 +524,10 @@ class Operator(BaseActor):
)
return
# check if there is already pending tx for this ritual + round combination
async_tx = self.ritual_tracker.active_rituals.get(PhaseId(ritual_id, PHASE1))
# check if there is already pending tx for this ritual + round combination
async_tx = self.dkg_storage.get_ritual_phase_async_tx(
phase_id=PhaseId(ritual_id, PHASE1)
)
if async_tx:
self.log.info(
f"Active ritual in progress: {self.transacting_power.account} has submitted tx "
@ -611,12 +616,12 @@ class Operator(BaseActor):
return
# check if there is a pending tx for this ritual + round combination
async_tx = self.ritual_tracker.active_rituals.get(
PhaseId(ritual_id=ritual_id, phase=PHASE2)
async_tx = self.dkg_storage.get_ritual_phase_async_tx(
phase_id=PhaseId(ritual_id, PHASE2)
)
if async_tx:
self.log.info(
f"Active ritual in progress Node {self.transacting_power.account} has submitted tx"
f"Active ritual in progress: {self.transacting_power.account} has submitted tx"
f"for ritual #{ritual_id}, phase #{PHASE2} (final: {async_tx.final})."
)
return async_tx

View File

@ -1,17 +1,15 @@
import datetime
import os
import time
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple
import maya
from atxm.tx import AsyncTx, FutureTx
from prometheus_client import REGISTRY, Gauge
from twisted.internet import threads
from web3.datastructures import AttributeDict
from nucypher.blockchain.eth.models import Coordinator
from nucypher.policy.conditions.utils import camel_case_to_snake
from nucypher.types import PhaseId
from nucypher.utilities.cache import TTLCache
from nucypher.utilities.events import EventScanner, JSONifiedState
from nucypher.utilities.logging import Logger
@ -123,8 +121,6 @@ class ActiveRitualTracker:
self.contract.events.EndRitual,
]
self.__phase_txs: Dict[PhaseId, FutureTx] = {}
# TODO: Remove the default JSON-RPC retry middleware
# as it correctly cannot handle eth_getLogs block range throttle down.
# self.web3.middleware_onion.remove(http_retry_request_middleware)
@ -167,10 +163,6 @@ class ActiveRitualTracker:
def contract(self):
return self.coordinator_agent.contract
@property
def active_rituals(self) -> Dict[PhaseId, AsyncTx]:
return self.__phase_txs
# TODO: should sample_window_size be additionally configurable/chain-dependent?
def _get_first_scan_start_block_number(self, sample_window_size: int = 100) -> int:
"""

View File

@ -1,30 +1,31 @@
from collections import defaultdict
from typing import List, Optional
from hexbytes import HexBytes
from atxm.tx import AsyncTx
from nucypher_core.ferveo import (
Validator,
)
from nucypher.blockchain.eth.models import Coordinator
from nucypher.blockchain.eth.models import PHASE1, Coordinator
from nucypher.types import PhaseId
class DKGStorage:
"""A simple in-memory storage for DKG data"""
# round 1
KEY_TRANSCRIPT_TXS = "transcript_tx_hashes"
KEY_VALIDATORS = "validators"
_KEY_PHASE_1_TXS = "phase_1_txs"
_KEY_VALIDATORS = "validators"
# round 2
KEY_AGGREGATED_TXS = "aggregation_tx_hashes"
_KEY_PHASE_2_TXS = "phase_2_txs"
# active rituals
KEY_ACTIVE_RITUAL = "active_rituals"
_KEY_ACTIVE_RITUAL = "active_rituals"
_KEYS = [
KEY_TRANSCRIPT_TXS,
KEY_VALIDATORS,
KEY_AGGREGATED_TXS,
KEY_ACTIVE_RITUAL,
_KEY_PHASE_1_TXS,
_KEY_VALIDATORS,
_KEY_PHASE_2_TXS,
_KEY_ACTIVE_RITUAL,
]
def __init__(self):
@ -38,45 +39,39 @@ class DKGStorage:
continue
#
# DKG Round 1 - Transcripts
# DKG Phases
#
def store_transcript_txhash(self, ritual_id: int, txhash: HexBytes) -> None:
self.data[self.KEY_TRANSCRIPT_TXS][ritual_id] = txhash
@classmethod
def __get_phase_key(cls, phase: int) -> str:
    """Map a DKG phase number to its internal storage-dict key.

    PHASE1 maps to the phase-1 tx key; any other phase value falls
    through to the phase-2 key (no validation of unknown phases).
    """
    if phase == PHASE1:
        return cls._KEY_PHASE_1_TXS
    return cls._KEY_PHASE_2_TXS
def clear_transcript_txhash(self, ritual_id: int, txhash: HexBytes) -> bool:
if self.get_transcript_txhash(ritual_id) == txhash:
del self.data[self.KEY_TRANSCRIPT_TXS][ritual_id]
def store_ritual_phase_async_tx(self, phase_id: PhaseId, async_tx: AsyncTx) -> None:
    """Cache the async tx submitted for the given ritual phase.

    Overwrites any previously stored tx for the same ritual + phase.
    """
    key = self.__get_phase_key(phase_id.phase)
    self.data[key][phase_id.ritual_id] = async_tx
def clear_ritual_phase_async_tx(self, phase_id: PhaseId, async_tx: AsyncTx) -> bool:
    """Remove the cached tx for this ritual phase, but only when it is
    the very same object that was stored (identity check guards against
    clearing a newer replacement tx).

    Returns True if an entry was cleared, False otherwise.
    """
    key = self.__get_phase_key(phase_id.phase)
    # NOTE(review): indexing (not .get) raises KeyError if nothing was ever
    # stored for this ritual_id — confirm callers always store before clearing.
    if self.data[key][phase_id.ritual_id] is async_tx:
        del self.data[key][phase_id.ritual_id]
        return True
    return False
def get_transcript_txhash(self, ritual_id: int) -> Optional[HexBytes]:
return self.data[self.KEY_TRANSCRIPT_TXS].get(ritual_id)
def get_ritual_phase_async_tx(self, phase_id: PhaseId) -> Optional[AsyncTx]:
    """Return the cached async tx for this ritual phase, or None if absent."""
    key = self.__get_phase_key(phase_id.phase)
    return self.data[key].get(phase_id.ritual_id)
def store_validators(self, ritual_id: int, validators: List[Validator]) -> None:
self.data[self.KEY_VALIDATORS][ritual_id] = list(validators)
self.data[self._KEY_VALIDATORS][ritual_id] = list(validators)
def get_validators(self, ritual_id: int) -> Optional[List[Validator]]:
validators = self.data[self.KEY_VALIDATORS].get(ritual_id)
validators = self.data[self._KEY_VALIDATORS].get(ritual_id)
if not validators:
return None
return list(validators)
#
# DKG Round 2 - Aggregation
#
def store_aggregation_txhash(self, ritual_id: int, txhash: HexBytes) -> None:
self.data[self.KEY_AGGREGATED_TXS][ritual_id] = txhash
def clear_aggregated_txhash(self, ritual_id: int, txhash: HexBytes) -> bool:
if self.get_aggregation_txhash(ritual_id) == txhash:
del self.data[self.KEY_AGGREGATED_TXS][ritual_id]
return True
return False
def get_aggregation_txhash(self, ritual_id: int) -> Optional[HexBytes]:
return self.data[self.KEY_AGGREGATED_TXS].get(ritual_id)
#
# Active Rituals
#
@ -84,7 +79,7 @@ class DKGStorage:
if active_ritual.total_aggregations != active_ritual.dkg_size:
# safeguard against a non-active ritual being cached
raise ValueError("Only active rituals can be cached")
self.data[self.KEY_ACTIVE_RITUAL][active_ritual.id] = active_ritual
self.data[self._KEY_ACTIVE_RITUAL][active_ritual.id] = active_ritual
def get_active_ritual(self, ritual_id: int) -> Optional[Coordinator.Ritual]:
return self.data[self.KEY_ACTIVE_RITUAL].get(ritual_id)
return self.data[self._KEY_ACTIVE_RITUAL].get(ritual_id)

View File

@ -141,16 +141,19 @@ def test_perform_round_1(
lambda *args, **kwargs: Coordinator.RitualStatus.DKG_AWAITING_TRANSCRIPTS
)
phase_id = PhaseId(ritual_id=0, phase=PHASE1)
assert (
ursula.dkg_storage.get_ritual_phase_async_tx(phase_id=phase_id) is None
), "no tx data as yet"
async_tx = ursula.perform_round_1(
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
)
# ensure tx is tracked
assert async_tx
assert len(ursula.ritual_tracker.active_rituals) == 1
pid01 = PhaseId(ritual_id=0, phase=PHASE1)
assert ursula.ritual_tracker.active_rituals[pid01]
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_id=phase_id) is async_tx
# try again
async_tx2 = ursula.perform_round_1(
@ -158,23 +161,18 @@ def test_perform_round_1(
)
assert async_tx2 is async_tx
assert len(ursula.ritual_tracker.active_rituals) == 1
assert ursula.ritual_tracker.active_rituals[pid01] is async_tx2
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_id=phase_id) is async_tx2
# participant already posted transcript
participant = agent.get_participant(
ritual_id=0, provider=ursula.checksum_address, transcript=False
)
participant.transcript = bytes(random_transcript)
# try submitting again
result = ursula.perform_round_1(
ritual_id=0, authority=random_address, participants=cohort, timestamp=0
)
assert result is None
assert len(ursula.ritual_tracker.active_rituals) == 1
assert ursula.ritual_tracker.active_rituals[pid01]
# participant no longer already posted aggregated transcript
participant.transcript = bytes()
@ -183,8 +181,7 @@ def test_perform_round_1(
)
assert async_tx3 is async_tx
assert len(ursula.ritual_tracker.active_rituals) == 1
assert ursula.ritual_tracker.active_rituals[pid01]
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_id=phase_id) is async_tx3
def test_perform_round_2(
@ -241,27 +238,33 @@ def test_perform_round_2(
agent.get_ritual_status = lambda *args, **kwargs: state
ursula.perform_round_2(ritual_id=0, timestamp=0)
assert len(ursula.ritual_tracker.active_rituals) == 1
pid01 = PhaseId(ritual_id=0, phase=PHASE1)
assert ursula.ritual_tracker.active_rituals[pid01]
phase_1_id = PhaseId(ritual_id=0, phase=PHASE1)
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_1_id) is not None
# set correct state
agent.get_ritual_status = (
lambda *args, **kwargs: Coordinator.RitualStatus.DKG_AWAITING_AGGREGATIONS
)
phase_2_id = PhaseId(ritual_id=0, phase=PHASE2)
assert (
ursula.dkg_storage.get_ritual_phase_async_tx(phase_id=phase_2_id) is None
), "no tx data as yet"
mocker.patch("nucypher.crypto.ferveo.dkg.verify_aggregate")
async_tx = ursula.perform_round_2(ritual_id=0, timestamp=0)
# check async tx tracking
assert len(ursula.ritual_tracker.active_rituals) == 2
pid02 = PhaseId(ritual_id=0, phase=PHASE2)
assert ursula.ritual_tracker.active_rituals[pid02]
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_2_id) is async_tx
assert (
ursula.dkg_storage.get_ritual_phase_async_tx(phase_1_id) is not async_tx
), "phase 1 separate from phase 2"
# trying again yields same tx
async_tx2 = ursula.perform_round_2(ritual_id=0, timestamp=0)
assert len(ursula.ritual_tracker.active_rituals) == 2
assert async_tx2 is async_tx
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_2_id) is async_tx2
# No action required
participant = agent.get_participant(
@ -275,3 +278,4 @@ def test_perform_round_2(
participant.aggregated = False
async_tx4 = ursula.perform_round_2(ritual_id=0, timestamp=0)
assert async_tx4 is async_tx
assert ursula.dkg_storage.get_ritual_phase_async_tx(phase_2_id) is async_tx4