Fix linter errors.

pull/3345/head
derekpierre 2023-11-13 15:58:13 -05:00 committed by Derek Pierre
parent 2e64a657c5
commit 4ed1ef12c9
1 changed files with 58 additions and 26 deletions

View File

@ -303,15 +303,21 @@ class SubscriptionManagerAgent(EthereumContractAgent):
class StakerSamplingApplicationAgent(EthereumContractAgent):
DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE = int(
os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE, default=30))
os.environ.get(
NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE, default=30
)
)
DEFAULT_PROVIDERS_PAGINATION_SIZE = int(
os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE, default=1000))
os.environ.get(NUCYPHER_ENVVAR_STAKING_PROVIDERS_PAGINATION_SIZE, default=1000)
)
class NotEnoughStakingProviders(Exception):
pass
@abstractmethod
def _get_active_staking_providers_raw(self, start_index: int, max_results: int) -> Tuple[int, List[bytes]]:
def _get_active_staking_providers_raw(
self, start_index: int, max_results: int
) -> Tuple[int, List[bytes]]:
raise NotImplementedError
@abstractmethod
@ -330,18 +336,20 @@ class StakerSamplingApplicationAgent(EthereumContractAgent):
def get_active_staking_providers(
self, start_index: int, max_results: int
) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
active_staking_providers_info = self._get_active_staking_providers_raw(start_index, max_results)
active_staking_providers_info = self._get_active_staking_providers_raw(
start_index, max_results
)
authorized_tokens, staking_providers = self._process_active_staker_info(
active_staking_providers_info
)
return authorized_tokens, staking_providers
def get_staking_provider_reservoir(self,
without: Iterable[ChecksumAddress] = None,
pagination_size: Optional[int] = None
) -> 'StakingProvidersReservoir':
def get_staking_provider_reservoir(
self,
without: Iterable[ChecksumAddress] = None,
pagination_size: Optional[int] = None,
) -> "StakingProvidersReservoir":
# pagination_size = pagination_size or self.get_staking_providers_population()
n_tokens, stake_provider_map = self.get_all_active_staking_providers(
pagination_size=pagination_size
@ -358,13 +366,17 @@ class StakerSamplingApplicationAgent(EthereumContractAgent):
del stake_provider_map[address]
filtered_out += 1
self.log.debug(f"Got {len(stake_provider_map)} staking providers with {n_tokens} total tokens "
f"({filtered_out} filtered out)")
self.log.debug(
f"Got {len(stake_provider_map)} staking providers with {n_tokens} total tokens "
f"({filtered_out} filtered out)"
)
return StakingProvidersReservoir(stake_provider_map)
@staticmethod
def _process_active_staker_info(active_staking_providers_info: Tuple[int, List[bytes]]) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
def _process_active_staker_info(
active_staking_providers_info: Tuple[int, List[bytes]]
) -> Tuple[types.TuNits, Dict[ChecksumAddress, types.TuNits]]:
total_authorized_tokens, staking_providers_info = active_staking_providers_info
staking_providers = dict()
for info in staking_providers_info:
@ -376,11 +388,13 @@ class StakerSamplingApplicationAgent(EthereumContractAgent):
return types.TuNits(total_authorized_tokens), staking_providers
def _get_active_stakers(
self,
pagination_size: Optional[int] = None):
def _get_active_stakers(self, pagination_size: Optional[int] = None):
if pagination_size is None:
pagination_size = self.DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE if self.blockchain.is_light else self.DEFAULT_PROVIDERS_PAGINATION_SIZE
pagination_size = (
self.DEFAULT_PROVIDERS_PAGINATION_SIZE_LIGHT_NODE
if self.blockchain.is_light
else self.DEFAULT_PROVIDERS_PAGINATION_SIZE
)
self.log.debug(f"Defaulting to pagination size {pagination_size}")
elif pagination_size < 0:
raise ValueError("Pagination size must be >= 0")
@ -399,7 +413,7 @@ class StakerSamplingApplicationAgent(EthereumContractAgent):
batch_staking_providers,
) = self.get_active_staking_providers(start_index, pagination_size)
except Exception as e:
if 'timeout' not in str(e):
if "timeout" not in str(e):
# exception unrelated to pagination size and timeout
raise e
elif pagination_size == 1 or attempts >= 3:
@ -411,7 +425,8 @@ class StakerSamplingApplicationAgent(EthereumContractAgent):
pagination_size = old_pagination_size // 2
self.log.debug(
f"Failed staking providers sampling using pagination size = {old_pagination_size}."
f"Retrying with size {pagination_size}")
f"Retrying with size {pagination_size}"
)
else:
n_tokens = n_tokens + batch_authorized_tokens
staking_providers.update(batch_staking_providers)
@ -487,7 +502,9 @@ class TACoChildApplicationAgent(StakerSamplingApplicationAgent):
return providers
@contract_api(CONTRACT_CALL)
def _get_active_staking_providers_raw(self, start_index: int, max_results: int) -> Tuple[int, List[bytes]]:
def _get_active_staking_providers_raw(
self, start_index: int, max_results: int
) -> Tuple[int, List[bytes]]:
active_staking_providers_info = (
self.contract.functions.getActiveStakingProviders(
start_index, max_results
@ -520,13 +537,21 @@ class TACoApplicationAgent(StakerSamplingApplicationAgent):
return result
@contract_api(CONTRACT_CALL)
def get_staking_provider_from_operator(self, operator_address: ChecksumAddress) -> ChecksumAddress:
result = self.contract.functions.stakingProviderFromOperator(operator_address).call()
def get_staking_provider_from_operator(
self, operator_address: ChecksumAddress
) -> ChecksumAddress:
result = self.contract.functions.stakingProviderFromOperator(
operator_address
).call()
return result
@contract_api(CONTRACT_CALL)
def get_operator_from_staking_provider(self, staking_provider: ChecksumAddress) -> ChecksumAddress:
result = self.contract.functions.getOperatorFromStakingProvider(staking_provider).call()
def get_operator_from_staking_provider(
self, staking_provider: ChecksumAddress
) -> ChecksumAddress:
result = self.contract.functions.getOperatorFromStakingProvider(
staking_provider
).call()
return result
@contract_api(CONTRACT_CALL)
@ -566,17 +591,24 @@ class TACoApplicationAgent(StakerSamplingApplicationAgent):
def get_staking_providers(self) -> List[ChecksumAddress]:
"""Returns a list of staking provider addresses"""
num_providers: int = self.get_staking_providers_population()
providers: List[ChecksumAddress] = [self.contract.functions.stakingProviders(i).call() for i in range(num_providers)]
providers: List[ChecksumAddress] = [
self.contract.functions.stakingProviders(i).call()
for i in range(num_providers)
]
return providers
@contract_api(CONTRACT_CALL)
def swarm(self) -> Iterable[ChecksumAddress]:
for index in range(self.get_staking_providers_population()):
address: ChecksumAddress = self.contract.functions.stakingProviders(index).call()
address: ChecksumAddress = self.contract.functions.stakingProviders(
index
).call()
yield address
@contract_api(CONTRACT_CALL)
def _get_active_staking_providers_raw(self, start_index: int, max_results: int) -> Tuple[int, List[bytes]]:
def _get_active_staking_providers_raw(
self, start_index: int, max_results: int
) -> Tuple[int, List[bytes]]:
active_staking_providers_info = (
self.contract.functions.getActiveStakingProviders(
start_index, max_results