Consolidate swarm/get_staking_providers into one method that returns an iterable.

Update tests.
pull/3345/head
derekpierre 2023-11-17 10:08:55 -05:00 committed by Derek Pierre
parent 74b5495321
commit ff51de058e
3 changed files with 19 additions and 35 deletions

View File

@ -489,14 +489,14 @@ class TACoChildApplicationAgent(StakerSamplingApplicationAgent):
return result
@contract_api(CONTRACT_CALL)
def get_staking_providers(self) -> List[ChecksumAddress]:
"""Returns a list of staking provider addresses"""
num_providers: int = self.get_staking_providers_population()
providers: List[ChecksumAddress] = [
self.contract.functions.stakingProviders(i).call()
for i in range(num_providers)
]
return providers
def get_staking_providers(self) -> Iterable[ChecksumAddress]:
"""Returns an iterable of staking provider addresses"""
num_providers = self.get_staking_providers_population()
for index in range(num_providers):
address: ChecksumAddress = self.contract.functions.stakingProviders(
index
).call()
yield address
@contract_api(CONTRACT_CALL)
def _get_active_staking_providers_raw(
@ -585,18 +585,10 @@ class TACoApplicationAgent(StakerSamplingApplicationAgent):
return result
@contract_api(CONTRACT_CALL)
def get_staking_providers(self) -> List[ChecksumAddress]:
"""Returns a list of staking provider addresses"""
num_providers: int = self.get_staking_providers_population()
providers: List[ChecksumAddress] = [
self.contract.functions.stakingProviders(i).call()
for i in range(num_providers)
]
return providers
@contract_api(CONTRACT_CALL)
def swarm(self) -> Iterable[ChecksumAddress]:
for index in range(self.get_staking_providers_population()):
def get_staking_providers(self) -> Iterable[ChecksumAddress]:
"""Returns an iterable of staking provider addresses"""
num_providers = self.get_staking_providers_population()
for index in range(num_providers):
address: ChecksumAddress = self.contract.functions.stakingProviders(
index
).call()

View File

@ -1,7 +1,6 @@
import random
import pytest
from eth_utils import is_address
from nucypher.blockchain.eth.agents import TACoApplicationAgent
from nucypher.blockchain.eth.constants import NULL_ADDRESS
@ -88,22 +87,13 @@ def test_get_staker_population(taco_application_agent, staking_providers):
)
def test_get_swarm(taco_application_agent, staking_providers):
swarm = taco_application_agent.swarm()
swarm_addresses = list(swarm)
assert len(swarm_addresses) == len(staking_providers) + 1
# Grab a staker address from the swarm
provider_addr = swarm_addresses[0]
assert isinstance(provider_addr, str)
assert is_address(provider_addr)
@pytest.mark.usefixtures("staking_providers", "ursulas")
def test_sample_staking_providers(taco_application_agent):
all_staking_providers = taco_application_agent.get_staking_providers()
all_staking_providers = list(taco_application_agent.get_staking_providers())
providers_population = taco_application_agent.get_staking_providers_population()
assert len(all_staking_providers) == providers_population
with pytest.raises(taco_application_agent.NotEnoughStakingProviders):
taco_application_agent.get_staking_provider_reservoir().draw(
providers_population + 1

View File

@ -44,7 +44,7 @@ def test_staking_provider_info(
ursulas,
get_random_checksum_address,
):
staking_providers = taco_child_application_agent.get_staking_providers()
staking_providers = list(taco_child_application_agent.get_staking_providers())
for ursula in ursulas:
provider_info = taco_child_application_agent.staking_provider_info(
@ -102,11 +102,13 @@ def test_get_staker_population(taco_child_application_agent, staking_providers):
@pytest.mark.usefixtures("staking_providers", "ursulas")
def test_sample_staking_providers(taco_child_application_agent):
all_staking_providers = taco_child_application_agent.get_staking_providers()
all_staking_providers = list(taco_child_application_agent.get_staking_providers())
providers_population = (
taco_child_application_agent.get_staking_providers_population()
)
assert len(all_staking_providers) == providers_population
with pytest.raises(taco_child_application_agent.NotEnoughStakingProviders):
taco_child_application_agent.get_staking_provider_reservoir().draw(
providers_population + 1