mirror of https://github.com/nucypher/nucypher.git
Commonize functionality for staker sampling for both TACoApplication and TACoChildApplication.
parent
0f8ea5e067
commit
20f2702649
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import random
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from bisect import bisect_right
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import accumulate
|
||||
|
@ -300,7 +301,131 @@ class SubscriptionManagerAgent(EthereumContractAgent):
|
|||
return receipt
|
||||
|
||||
|
||||
class TACoChildApplicationAgent(EthereumContractAgent):
|
||||
class StakerSamplingApplicationAgent(EthereumContractAgent):
|
||||
DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE = int(
|
||||
os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE, default=30))
|
||||
DEFAULT_PROVIDERS_PAGINATION_SIZE = int(
|
||||
os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE, default=1000))
|
||||
|
||||
class NotEnoughStakingProviders(Exception):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_active_staking_providers_raw(self, start_index: int, max_results: int) -> Tuple[int, List[bytes]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_staking_providers_population(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_all_active_staking_providers(
|
||||
self, pagination_size: Optional[int] = None
|
||||
) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
|
||||
n_tokens, staking_providers = self._get_active_stakers(
|
||||
pagination_size=pagination_size
|
||||
)
|
||||
return n_tokens, staking_providers
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_active_staking_providers(
|
||||
self, start_index: int, max_results: int
|
||||
) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
|
||||
active_staking_providers_info = self._get_active_staking_providers_raw(start_index, max_results)
|
||||
|
||||
authorized_tokens, staking_providers = self._process_active_staker_info(
|
||||
active_staking_providers_info
|
||||
)
|
||||
return authorized_tokens, staking_providers
|
||||
|
||||
def get_staking_provider_reservoir(self,
|
||||
without: Iterable[ChecksumAddress] = None,
|
||||
pagination_size: Optional[int] = None
|
||||
) -> 'StakingProvidersReservoir':
|
||||
|
||||
# pagination_size = pagination_size or self.get_staking_providers_population()
|
||||
n_tokens, stake_provider_map = self.get_all_active_staking_providers(
|
||||
pagination_size=pagination_size
|
||||
)
|
||||
|
||||
if n_tokens == 0:
|
||||
raise self.NotEnoughStakingProviders("There are no locked tokens.")
|
||||
|
||||
filtered_out = 0
|
||||
if without:
|
||||
for address in without:
|
||||
if address in stake_provider_map:
|
||||
n_tokens -= stake_provider_map[address]
|
||||
del stake_provider_map[address]
|
||||
filtered_out += 1
|
||||
|
||||
self.log.debug(f"Got {len(stake_provider_map)} staking providers with {n_tokens} total tokens "
|
||||
f"({filtered_out} filtered out)")
|
||||
|
||||
return StakingProvidersReservoir(stake_provider_map)
|
||||
|
||||
@staticmethod
|
||||
def _process_active_staker_info(active_staking_providers_info: Tuple[int, List[bytes]]) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
|
||||
total_authorized_tokens, staking_providers_info = active_staking_providers_info
|
||||
staking_providers = dict()
|
||||
for info in staking_providers_info:
|
||||
staking_provider_address = to_checksum_address(info[0:20])
|
||||
staking_provider_authorized_tokens = to_int(info[20:32])
|
||||
staking_providers[staking_provider_address] = types.TuNits(
|
||||
staking_provider_authorized_tokens
|
||||
)
|
||||
|
||||
return types.TuNits(total_authorized_tokens), staking_providers
|
||||
|
||||
def _get_active_stakers(
|
||||
self,
|
||||
pagination_size: Optional[int] = None):
|
||||
if pagination_size is None:
|
||||
pagination_size = self.DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE if self.blockchain.is_light else self.DEFAULT_PROVIDERS_PAGINATION_SIZE
|
||||
self.log.debug(f"Defaulting to pagination size {pagination_size}")
|
||||
elif pagination_size < 0:
|
||||
raise ValueError("Pagination size must be >= 0")
|
||||
|
||||
if pagination_size > 0:
|
||||
num_providers: int = self.get_staking_providers_population()
|
||||
start_index: int = 0
|
||||
n_tokens = types.TuNits(0)
|
||||
staking_providers: Dict[ChecksumAddress, types.TuNits] = dict()
|
||||
attempts: int = 0
|
||||
while start_index < num_providers:
|
||||
try:
|
||||
attempts += 1
|
||||
(
|
||||
batch_authorized_tokens,
|
||||
batch_staking_providers,
|
||||
) = self.get_active_staking_providers(start_index, pagination_size)
|
||||
except Exception as e:
|
||||
if 'timeout' not in str(e):
|
||||
# exception unrelated to pagination size and timeout
|
||||
raise e
|
||||
elif pagination_size == 1 or attempts >= 3:
|
||||
# we tried
|
||||
raise e
|
||||
else:
|
||||
# reduce pagination size and retry
|
||||
old_pagination_size = pagination_size
|
||||
pagination_size = old_pagination_size // 2
|
||||
self.log.debug(
|
||||
f"Failed staking providers sampling using pagination size = {old_pagination_size}."
|
||||
f"Retrying with size {pagination_size}")
|
||||
else:
|
||||
n_tokens = n_tokens + batch_authorized_tokens
|
||||
staking_providers.update(batch_staking_providers)
|
||||
start_index += pagination_size
|
||||
|
||||
else:
|
||||
n_tokens, staking_providers = self.get_active_staking_providers(
|
||||
start_index=0, max_results=0
|
||||
)
|
||||
|
||||
return n_tokens, staking_providers
|
||||
|
||||
|
||||
class TACoChildApplicationAgent(StakerSamplingApplicationAgent):
|
||||
contract_name: str = TACO_CHILD_APPLICATION_CONTRACT_NAME
|
||||
|
||||
class StakingProviderInfo(NamedTuple):
|
||||
|
@ -327,6 +452,7 @@ class TACoChildApplicationAgent(EthereumContractAgent):
|
|||
result = self.contract.functions.stakingProviderInfo(staking_provider).call()
|
||||
return TACoChildApplicationAgent.StakingProviderInfo(*result)
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def is_operator_confirmed(self, operator_address: ChecksumAddress) -> bool:
|
||||
staking_provider = self.staking_provider_from_operator(operator_address)
|
||||
if staking_provider == NULL_ADDRESS:
|
||||
|
@ -336,20 +462,24 @@ class TACoChildApplicationAgent(EthereumContractAgent):
|
|||
return staking_provider_info.operator_confirmed
|
||||
|
||||
|
||||
class TACoApplicationAgent(EthereumContractAgent):
|
||||
contract_name: str = TACO_APPLICATION_CONTRACT_NAME
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def _get_active_staking_providers_raw(self, start_index: int, max_results: int) -> Tuple[int, List[bytes]]:
|
||||
active_staking_providers_info = (
|
||||
self.contract.functions.getActiveStakingProviders(
|
||||
start_index, max_results
|
||||
).call()
|
||||
)
|
||||
return active_staking_providers_info
|
||||
|
||||
DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE = int(os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE, default=30))
|
||||
DEFAULT_PROVIDERS_PAGINATION_SIZE = int(os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE, default=1000))
|
||||
|
||||
class TACoApplicationAgent(StakerSamplingApplicationAgent):
|
||||
contract_name: str = TACO_APPLICATION_CONTRACT_NAME
|
||||
|
||||
class StakingProviderInfo(NamedTuple):
|
||||
operator: ChecksumAddress
|
||||
operator_confirmed: bool
|
||||
operator_start_timestamp: int
|
||||
|
||||
class NotEnoughStakingProviders(Exception):
|
||||
pass
|
||||
|
||||
class OperatorInfo(NamedTuple):
|
||||
address: ChecksumAddress
|
||||
confirmed: bool
|
||||
|
@ -415,26 +545,6 @@ class TACoApplicationAgent(EthereumContractAgent):
|
|||
providers: List[ChecksumAddress] = [self.contract.functions.stakingProviders(i).call() for i in range(num_providers)]
|
||||
return providers
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_active_staking_providers(
|
||||
self, start_index: int, max_results: int
|
||||
) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
|
||||
active_staking_providers_info = (
|
||||
self.contract.functions.getActiveStakingProviders(
|
||||
start_index, max_results
|
||||
).call()
|
||||
)
|
||||
total_authorized_tokens, staking_providers_info = active_staking_providers_info
|
||||
staking_providers = dict()
|
||||
for info in staking_providers_info:
|
||||
staking_provider_address = to_checksum_address(info[0:20])
|
||||
staking_provider_authorized_tokens = to_int(info[20:32])
|
||||
staking_providers[staking_provider_address] = types.TuNits(
|
||||
staking_provider_authorized_tokens
|
||||
)
|
||||
|
||||
return types.TuNits(total_authorized_tokens), staking_providers
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def swarm(self) -> Iterable[ChecksumAddress]:
|
||||
for index in range(self.get_staking_providers_population()):
|
||||
|
@ -442,75 +552,13 @@ class TACoApplicationAgent(EthereumContractAgent):
|
|||
yield address
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_all_active_staking_providers(
|
||||
self, pagination_size: Optional[int] = None
|
||||
) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
|
||||
if pagination_size is None:
|
||||
pagination_size = self.DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE if self.blockchain.is_light else self.DEFAULT_PROVIDERS_PAGINATION_SIZE
|
||||
self.log.debug(f"Defaulting to pagination size {pagination_size}")
|
||||
elif pagination_size < 0:
|
||||
raise ValueError("Pagination size must be >= 0")
|
||||
|
||||
if pagination_size > 0:
|
||||
num_providers: int = self.get_staking_providers_population()
|
||||
start_index: int = 0
|
||||
n_tokens = types.TuNits(0)
|
||||
staking_providers: Dict[ChecksumAddress, types.TuNits] = dict()
|
||||
attempts: int = 0
|
||||
while start_index < num_providers:
|
||||
try:
|
||||
attempts += 1
|
||||
(
|
||||
batch_authorized_tokens,
|
||||
batch_staking_providers,
|
||||
) = self.get_active_staking_providers(start_index, pagination_size)
|
||||
except Exception as e:
|
||||
if 'timeout' not in str(e):
|
||||
# exception unrelated to pagination size and timeout
|
||||
raise e
|
||||
elif pagination_size == 1 or attempts >= 3:
|
||||
# we tried
|
||||
raise e
|
||||
else:
|
||||
# reduce pagination size and retry
|
||||
old_pagination_size = pagination_size
|
||||
pagination_size = old_pagination_size // 2
|
||||
self.log.debug(f"Failed staking providers sampling using pagination size = {old_pagination_size}."
|
||||
f"Retrying with size {pagination_size}")
|
||||
else:
|
||||
n_tokens = n_tokens + batch_authorized_tokens
|
||||
staking_providers.update(batch_staking_providers)
|
||||
start_index += pagination_size
|
||||
|
||||
else:
|
||||
n_tokens, staking_providers = self.get_active_staking_providers(
|
||||
start_index=0, max_results=0
|
||||
)
|
||||
|
||||
return n_tokens, staking_providers
|
||||
|
||||
def get_staking_provider_reservoir(self,
|
||||
without: Iterable[ChecksumAddress] = None,
|
||||
pagination_size: Optional[int] = None
|
||||
) -> 'StakingProvidersReservoir':
|
||||
|
||||
# pagination_size = pagination_size or self.get_staking_providers_population()
|
||||
n_tokens, stake_provider_map = self.get_all_active_staking_providers(pagination_size=pagination_size)
|
||||
|
||||
filtered_out = 0
|
||||
if without:
|
||||
for address in without:
|
||||
if address in stake_provider_map:
|
||||
n_tokens -= stake_provider_map[address]
|
||||
del stake_provider_map[address]
|
||||
filtered_out += 1
|
||||
|
||||
self.log.debug(f"Got {len(stake_provider_map)} staking providers with {n_tokens} total tokens "
|
||||
f"({filtered_out} filtered out)")
|
||||
if n_tokens == 0:
|
||||
raise self.NotEnoughStakingProviders("There are no locked tokens.")
|
||||
|
||||
return StakingProvidersReservoir(stake_provider_map)
|
||||
def _get_active_staking_providers_raw(self, start_index: int, max_results: int) -> Tuple[int, List[bytes]]:
|
||||
active_staking_providers_info = (
|
||||
self.contract.functions.getActiveStakingProviders(
|
||||
start_index, max_results
|
||||
).call()
|
||||
)
|
||||
return active_staking_providers_info
|
||||
|
||||
#
|
||||
# Transactions
|
||||
|
@ -976,7 +1024,7 @@ class StakingProvidersReservoir:
|
|||
|
||||
def draw(self, quantity):
|
||||
if quantity > len(self):
|
||||
raise TACoApplicationAgent.NotEnoughStakingProviders(
|
||||
raise StakerSamplingApplicationAgent.NotEnoughStakingProviders(
|
||||
f"Cannot sample {quantity} out of {len(self)} total staking providers"
|
||||
)
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ def make_staking_provider_reservoir(
|
|||
without_set = set(include_addresses) | set(exclude_addresses or ())
|
||||
try:
|
||||
reservoir = application_agent.get_staking_provider_reservoir(without=without_set, pagination_size=pagination_size)
|
||||
except TACoApplicationAgent.NotEnoughStakingProviders:
|
||||
except application_agent.NotEnoughStakingProviders:
|
||||
# TODO: do that in `get_staking_provider_reservoir()`?
|
||||
reservoir = StakingProvidersReservoir({})
|
||||
|
||||
|
|
Loading…
Reference in New Issue