Commonize functionality for staker sampling for both TACoApplication and TACoChildApplication.

pull/3345/head
derekpierre 2023-11-13 15:04:31 -05:00 committed by Derek Pierre
parent 0f8ea5e067
commit 20f2702649
2 changed files with 147 additions and 99 deletions

View File

@ -1,6 +1,7 @@
import os
import random
import sys
from abc import abstractmethod
from bisect import bisect_right
from dataclasses import dataclass, field
from itertools import accumulate
@ -300,7 +301,131 @@ class SubscriptionManagerAgent(EthereumContractAgent):
return receipt
class TACoChildApplicationAgent(EthereumContractAgent):
class StakerSamplingApplicationAgent(EthereumContractAgent):
DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE = int(
os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE, default=30))
DEFAULT_PROVIDERS_PAGINATION_SIZE = int(
os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE, default=1000))
class NotEnoughStakingProviders(Exception):
pass
@abstractmethod
def _get_active_staking_providers_raw(self, start_index: int, max_results: int) -> Tuple[int, List[bytes]]:
raise NotImplementedError
@abstractmethod
def get_staking_providers_population(self) -> int:
raise NotImplementedError
def get_all_active_staking_providers(
self, pagination_size: Optional[int] = None
) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
n_tokens, staking_providers = self._get_active_stakers(
pagination_size=pagination_size
)
return n_tokens, staking_providers
@contract_api(CONTRACT_CALL)
def get_active_staking_providers(
self, start_index: int, max_results: int
) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
active_staking_providers_info = self._get_active_staking_providers_raw(start_index, max_results)
authorized_tokens, staking_providers = self._process_active_staker_info(
active_staking_providers_info
)
return authorized_tokens, staking_providers
def get_staking_provider_reservoir(self,
without: Iterable[ChecksumAddress] = None,
pagination_size: Optional[int] = None
) -> 'StakingProvidersReservoir':
# pagination_size = pagination_size or self.get_staking_providers_population()
n_tokens, stake_provider_map = self.get_all_active_staking_providers(
pagination_size=pagination_size
)
if n_tokens == 0:
raise self.NotEnoughStakingProviders("There are no locked tokens.")
filtered_out = 0
if without:
for address in without:
if address in stake_provider_map:
n_tokens -= stake_provider_map[address]
del stake_provider_map[address]
filtered_out += 1
self.log.debug(f"Got {len(stake_provider_map)} staking providers with {n_tokens} total tokens "
f"({filtered_out} filtered out)")
return StakingProvidersReservoir(stake_provider_map)
@staticmethod
def _process_active_staker_info(active_staking_providers_info: Tuple[int, List[bytes]]) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
total_authorized_tokens, staking_providers_info = active_staking_providers_info
staking_providers = dict()
for info in staking_providers_info:
staking_provider_address = to_checksum_address(info[0:20])
staking_provider_authorized_tokens = to_int(info[20:32])
staking_providers[staking_provider_address] = types.TuNits(
staking_provider_authorized_tokens
)
return types.TuNits(total_authorized_tokens), staking_providers
def _get_active_stakers(
self,
pagination_size: Optional[int] = None):
if pagination_size is None:
pagination_size = self.DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE if self.blockchain.is_light else self.DEFAULT_PROVIDERS_PAGINATION_SIZE
self.log.debug(f"Defaulting to pagination size {pagination_size}")
elif pagination_size < 0:
raise ValueError("Pagination size must be >= 0")
if pagination_size > 0:
num_providers: int = self.get_staking_providers_population()
start_index: int = 0
n_tokens = types.TuNits(0)
staking_providers: Dict[ChecksumAddress, types.TuNits] = dict()
attempts: int = 0
while start_index < num_providers:
try:
attempts += 1
(
batch_authorized_tokens,
batch_staking_providers,
) = self.get_active_staking_providers(start_index, pagination_size)
except Exception as e:
if 'timeout' not in str(e):
# exception unrelated to pagination size and timeout
raise e
elif pagination_size == 1 or attempts >= 3:
# we tried
raise e
else:
# reduce pagination size and retry
old_pagination_size = pagination_size
pagination_size = old_pagination_size // 2
self.log.debug(
f"Failed staking providers sampling using pagination size = {old_pagination_size}."
f"Retrying with size {pagination_size}")
else:
n_tokens = n_tokens + batch_authorized_tokens
staking_providers.update(batch_staking_providers)
start_index += pagination_size
else:
n_tokens, staking_providers = self.get_active_staking_providers(
start_index=0, max_results=0
)
return n_tokens, staking_providers
class TACoChildApplicationAgent(StakerSamplingApplicationAgent):
contract_name: str = TACO_CHILD_APPLICATION_CONTRACT_NAME
class StakingProviderInfo(NamedTuple):
@ -327,6 +452,7 @@ class TACoChildApplicationAgent(EthereumContractAgent):
result = self.contract.functions.stakingProviderInfo(staking_provider).call()
return TACoChildApplicationAgent.StakingProviderInfo(*result)
@contract_api(CONTRACT_CALL)
def is_operator_confirmed(self, operator_address: ChecksumAddress) -> bool:
staking_provider = self.staking_provider_from_operator(operator_address)
if staking_provider == NULL_ADDRESS:
@ -336,20 +462,24 @@ class TACoChildApplicationAgent(EthereumContractAgent):
return staking_provider_info.operator_confirmed
class TACoApplicationAgent(EthereumContractAgent):
contract_name: str = TACO_APPLICATION_CONTRACT_NAME
@contract_api(CONTRACT_CALL)
def _get_active_staking_providers_raw(self, start_index: int, max_results: int) -> Tuple[int, List[bytes]]:
active_staking_providers_info = (
self.contract.functions.getActiveStakingProviders(
start_index, max_results
).call()
)
return active_staking_providers_info
DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE = int(os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE, default=30))
DEFAULT_PROVIDERS_PAGINATION_SIZE = int(os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE, default=1000))
class TACoApplicationAgent(StakerSamplingApplicationAgent):
contract_name: str = TACO_APPLICATION_CONTRACT_NAME
class StakingProviderInfo(NamedTuple):
operator: ChecksumAddress
operator_confirmed: bool
operator_start_timestamp: int
class NotEnoughStakingProviders(Exception):
pass
class OperatorInfo(NamedTuple):
address: ChecksumAddress
confirmed: bool
@ -415,26 +545,6 @@ class TACoApplicationAgent(EthereumContractAgent):
providers: List[ChecksumAddress] = [self.contract.functions.stakingProviders(i).call() for i in range(num_providers)]
return providers
@contract_api(CONTRACT_CALL)
def get_active_staking_providers(
self, start_index: int, max_results: int
) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
active_staking_providers_info = (
self.contract.functions.getActiveStakingProviders(
start_index, max_results
).call()
)
total_authorized_tokens, staking_providers_info = active_staking_providers_info
staking_providers = dict()
for info in staking_providers_info:
staking_provider_address = to_checksum_address(info[0:20])
staking_provider_authorized_tokens = to_int(info[20:32])
staking_providers[staking_provider_address] = types.TuNits(
staking_provider_authorized_tokens
)
return types.TuNits(total_authorized_tokens), staking_providers
@contract_api(CONTRACT_CALL)
def swarm(self) -> Iterable[ChecksumAddress]:
for index in range(self.get_staking_providers_population()):
@ -442,75 +552,13 @@ class TACoApplicationAgent(EthereumContractAgent):
yield address
@contract_api(CONTRACT_CALL)
def get_all_active_staking_providers(
self, pagination_size: Optional[int] = None
) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
if pagination_size is None:
pagination_size = self.DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE if self.blockchain.is_light else self.DEFAULT_PROVIDERS_PAGINATION_SIZE
self.log.debug(f"Defaulting to pagination size {pagination_size}")
elif pagination_size < 0:
raise ValueError("Pagination size must be >= 0")
if pagination_size > 0:
num_providers: int = self.get_staking_providers_population()
start_index: int = 0
n_tokens = types.TuNits(0)
staking_providers: Dict[ChecksumAddress, types.TuNits] = dict()
attempts: int = 0
while start_index < num_providers:
try:
attempts += 1
(
batch_authorized_tokens,
batch_staking_providers,
) = self.get_active_staking_providers(start_index, pagination_size)
except Exception as e:
if 'timeout' not in str(e):
# exception unrelated to pagination size and timeout
raise e
elif pagination_size == 1 or attempts >= 3:
# we tried
raise e
else:
# reduce pagination size and retry
old_pagination_size = pagination_size
pagination_size = old_pagination_size // 2
self.log.debug(f"Failed staking providers sampling using pagination size = {old_pagination_size}."
f"Retrying with size {pagination_size}")
else:
n_tokens = n_tokens + batch_authorized_tokens
staking_providers.update(batch_staking_providers)
start_index += pagination_size
else:
n_tokens, staking_providers = self.get_active_staking_providers(
start_index=0, max_results=0
)
return n_tokens, staking_providers
def get_staking_provider_reservoir(self,
without: Iterable[ChecksumAddress] = None,
pagination_size: Optional[int] = None
) -> 'StakingProvidersReservoir':
# pagination_size = pagination_size or self.get_staking_providers_population()
n_tokens, stake_provider_map = self.get_all_active_staking_providers(pagination_size=pagination_size)
filtered_out = 0
if without:
for address in without:
if address in stake_provider_map:
n_tokens -= stake_provider_map[address]
del stake_provider_map[address]
filtered_out += 1
self.log.debug(f"Got {len(stake_provider_map)} staking providers with {n_tokens} total tokens "
f"({filtered_out} filtered out)")
if n_tokens == 0:
raise self.NotEnoughStakingProviders("There are no locked tokens.")
return StakingProvidersReservoir(stake_provider_map)
def _get_active_staking_providers_raw(self, start_index: int, max_results: int) -> Tuple[int, List[bytes]]:
active_staking_providers_info = (
self.contract.functions.getActiveStakingProviders(
start_index, max_results
).call()
)
return active_staking_providers_info
#
# Transactions
@ -976,7 +1024,7 @@ class StakingProvidersReservoir:
def draw(self, quantity):
if quantity > len(self):
raise TACoApplicationAgent.NotEnoughStakingProviders(
raise StakerSamplingApplicationAgent.NotEnoughStakingProviders(
f"Cannot sample {quantity} out of {len(self)} total staking providers"
)

View File

@ -22,7 +22,7 @@ def make_staking_provider_reservoir(
without_set = set(include_addresses) | set(exclude_addresses or ())
try:
reservoir = application_agent.get_staking_provider_reservoir(without=without_set, pagination_size=pagination_size)
except TACoApplicationAgent.NotEnoughStakingProviders:
except application_agent.NotEnoughStakingProviders:
# TODO: do that in `get_staking_provider_reservoir()`?
reservoir = StakingProvidersReservoir({})