From 31715e55e44fffe8727793448de7c138a3229d62 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Thu, 28 May 2020 16:24:41 -0700 Subject: [PATCH 1/6] Simplify staker sampling and add unit tests for proper sampling distribution --- nucypher/blockchain/eth/agents.py | 92 ++++++++++++------- nucypher/policy/policies.py | 7 +- .../agents/test_sampling_distribution.py | 44 ++++++++- 3 files changed, 102 insertions(+), 41 deletions(-) diff --git a/nucypher/blockchain/eth/agents.py b/nucypher/blockchain/eth/agents.py index 277d697fb..1b28b1209 100644 --- a/nucypher/blockchain/eth/agents.py +++ b/nucypher/blockchain/eth/agents.py @@ -15,8 +15,9 @@ You should have received a copy of the GNU Affero General Public License along with nucypher. If not, see . """ +from bisect import bisect_right +from itertools import accumulate import random - import math import sys from constant_sorrow.constants import ( # type: ignore @@ -730,15 +731,11 @@ class StakingEscrowAgent(EthereumContractAgent): def sample(self, quantity: int, duration: int, - additional_ursulas: float = 1.5, - attempts: int = 5, pagination_size: Optional[int] = None ) -> List[ChecksumAddress]: """ Select n random Stakers, according to their stake distribution. - - The returned addresses are shuffled, so one can request more than needed and - throw away those which do not respond. + The returned addresses are shuffled. See full diagram here: https://github.com/nucypher/kms-whitepaper/blob/master/pdf/miners-ruler.pdf @@ -757,42 +754,30 @@ class StakingEscrowAgent(EthereumContractAgent): Only stakers which made a commitment to the current period (in the previous period) are used. """ - system_random = random.SystemRandom() n_tokens, stakers_map = self.get_all_active_stakers(periods=duration, pagination_size=pagination_size) + + # TODO: can be implemented as an iterator if necessary, where the user can + # sample addresses one by one without calling get_all_active_stakers() repeatedly. + if n_tokens == 0: raise self.NotEnoughStakers('There are no locked tokens for duration {}.'.format(duration)) - sample_size = quantity - for _ in range(attempts): - sample_size = math.ceil(sample_size * additional_ursulas) - points = sorted(system_random.randrange(n_tokens) for _ in range(sample_size)) - self.log.debug(f"Sampling {sample_size} stakers with random points: {points}") + if quantity > len(stakers_map): + raise self.NotEnoughStakers(f'Cannot sample {quantity} out of {len(stakers)} total stakers') - addresses = set() - stakers = list(stakers_map.items()) + addresses = list(stakers_map.keys()) + tokens = list(stakers_map.values()) + sampler = WeightedSampler(addresses, tokens) - point_index = 0 - sum_of_locked_tokens = 0 - staker_index = 0 - stakers_len = len(stakers) - while staker_index < stakers_len and point_index < sample_size: - current_staker = stakers[staker_index][0] - staker_tokens = stakers[staker_index][1] - next_sum_value = sum_of_locked_tokens + staker_tokens + system_random = random.SystemRandom() + sampled_addresses = sampler.sample_no_replacement(system_random, quantity) - point = points[point_index] - if sum_of_locked_tokens <= point < next_sum_value: - addresses.add(to_checksum_address(current_staker)) - point_index += 1 - else: - staker_index += 1 - sum_of_locked_tokens = next_sum_value + # Randomize the output to avoid the largest stakers always being the first in the list + system_random.shuffle(sampled_addresses) # inplace - self.log.debug(f"Sampled {len(addresses)} stakers: {list(addresses)}") - if len(addresses) >= quantity: - return system_random.sample(addresses, quantity) + self.log.debug(f"Sampled {len(addresses)} stakers: {list(sampled_addresses)}") - raise self.NotEnoughStakers('Selection failed after {} attempts'.format(attempts)) + return sampled_addresses @contract_api(CONTRACT_CALL) def get_completed_work(self, bidder_address: ChecksumAddress) -> Work: @@ -1650,3 +1635,44 @@ class ContractAgency: agent_class: Type[EthereumContractAgent] = getattr(agents_module, agent_name) agent: EthereumContractAgent = cls.get_agent(agent_class=agent_class, registry=registry, provider_uri=provider_uri) return agent + + +class WeightedSampler: + """ + Samples random elements with probabilities proportioinal to given weights. + """ + + def __init__(self, elements: Iterable, weights: Iterable[int]): + assert len(elements) == len(weights) + self.totals = list(accumulate(weights)) + self.elements = elements + + def sample_no_replacement(self, rng, quantity: int) -> list: + """ + Samples ``quantity`` of elements from the internal array. + The probablity of an element to appear is proportional + to the weight provided to the constructor. + + The elements will not repeat; every time an element is sampled its weight is set to 0. + (does not mutate the object and only applies to the current invocation of the method). + """ + + if quantity > len(self.totals): + raise ValueError("Cannot sample more than the total amount of elements without replacement") + + totals = self.totals.copy() + samples = [] + + for i in range(quantity): + position = rng.randint(0, totals[-1] - 1) + idx = bisect_right(totals, position) + samples.append(self.elements[idx]) + + # Adjust the totals so that they correspond + # to the weight of the element `idx` being set to 0. + prev_total = totals[idx - 1] if idx > 0 else 0 + weight = totals[idx] - prev_total + for j in range(idx, len(totals)): + totals[j] -= weight + + return samples diff --git a/nucypher/policy/policies.py b/nucypher/policy/policies.py index 22b154f99..363ff35a0 100644 --- a/nucypher/policy/policies.py +++ b/nucypher/policy/policies.py @@ -205,9 +205,6 @@ class Policy(ABC): self.treasure_map = TreasureMap(m=m) self.expiration = expiration - # Keep track of this stuff - self.selection_buffer = 1 - self._accepted_arrangements = set() # type: Set[Arrangement] self._rejected_arrangements = set() # type: Set[Arrangement] self._spare_candidates = set() # type: Set[Ursula] @@ -532,7 +529,6 @@ class BlockchainPolicy(Policy): super().__init__(alice=alice, expiration=expiration, *args, **kwargs) - self.selection_buffer = 1.5 self.validate_fee_value() def validate_fee_value(self) -> None: @@ -618,8 +614,7 @@ class BlockchainPolicy(Policy): selected_addresses = set() try: sampled_addresses = self.alice.recruit(quantity=quantity, - duration=self.duration_periods, - additional_ursulas=self.selection_buffer) + duration=self.duration_periods) except StakingEscrowAgent.NotEnoughStakers as e: error = f"Cannot create policy with {quantity} arrangements: {e}" raise self.NotEnoughBlockchainUrsulas(error) diff --git a/tests/acceptance/blockchain/agents/test_sampling_distribution.py b/tests/acceptance/blockchain/agents/test_sampling_distribution.py index 3c6e3b032..d86dbc969 100644 --- a/tests/acceptance/blockchain/agents/test_sampling_distribution.py +++ b/tests/acceptance/blockchain/agents/test_sampling_distribution.py @@ -16,11 +16,13 @@ along with nucypher. If not, see . """ from collections import Counter +from itertools import permutations +import random import pytest from nucypher.blockchain.economics import BaseEconomics -from nucypher.blockchain.eth.agents import StakingEscrowAgent +from nucypher.blockchain.eth.agents import StakingEscrowAgent, WeightedSampler from nucypher.blockchain.eth.constants import NULL_ADDRESS, STAKING_ESCROW_CONTRACT_NAME @@ -115,7 +117,7 @@ def test_sampling_distribution(testerchain, token, deploy_contract, token_econom sampled, failed = 0, 0 while sampled < SAMPLES: try: - addresses = set(staking_agent.sample(quantity=quantity, additional_ursulas=1, duration=1)) + addresses = set(staking_agent.sample(quantity=quantity, duration=1)) addresses.discard(NULL_ADDRESS) except staking_agent.NotEnoughStakers: failed += 1 @@ -134,3 +136,41 @@ def test_sampling_distribution(testerchain, token, deploy_contract, token_econom assert abs_error < ERROR_TOLERANCE # TODO: Test something wrt to % of failed + + +def probability_reference_no_replacement(weights, idxs): + """ + The probability of drawing elements with (distinct) indices ``idxs`` (in given order), + given ``weights``. No replacement. + """ + assert len(set(idxs)) == len(idxs) + all_weights = sum(weights) + p = 1 + for idx in idxs: + p *= weights[idx] / all_weights + all_weights -= weights[idx] + return p + + +@pytest.mark.parametrize('sample_size', [1, 2, 3]) +def test_weighted_sampler(sample_size): + weights = [1, 9, 100, 2, 18, 70] + rng = random.SystemRandom() + counter = Counter() + + elements = list(range(len(weights))) + + samples = 100000 + sampler = WeightedSampler(elements, weights) + for i in range(samples): + sample_set = sampler.sample_no_replacement(rng, sample_size) + counter.update({tuple(sample_set): 1}) + + for idxs in permutations(elements, sample_size): + test_prob = counter[idxs] / samples + ref_prob = probability_reference_no_replacement(weights, idxs) + + # A rough estimate to check probabilities. + # A little too forgiving for samples with smaller probabilities, + # but can go up to 0.5 on occasion. + assert abs(test_prob - ref_prob) * samples**0.5 < 1 From 9ca6c6d8fa8600ed70ee3506f251fb945c7fac20 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Tue, 9 Jun 2020 12:51:27 -0700 Subject: [PATCH 2/6] Expose staker sampling as an iterator (ish) --- nucypher/blockchain/eth/actors.py | 14 +- nucypher/blockchain/eth/agents.py | 84 ++++++------ nucypher/network/nodes.py | 9 +- nucypher/policy/policies.py | 120 ++++++++++-------- .../agents/test_sampling_distribution.py | 3 +- .../agents/test_staking_escrow_agent.py | 4 +- 6 files changed, 124 insertions(+), 110 deletions(-) diff --git a/nucypher/blockchain/eth/actors.py b/nucypher/blockchain/eth/actors.py index 09b937362..b3206aade 100644 --- a/nucypher/blockchain/eth/actors.py +++ b/nucypher/blockchain/eth/actors.py @@ -44,7 +44,8 @@ from nucypher.blockchain.eth.agents import ( PolicyManagerAgent, PreallocationEscrowAgent, StakingEscrowAgent, - WorkLockAgent + WorkLockAgent, + StakersReservoir, ) from nucypher.blockchain.eth.constants import NULL_ADDRESS from nucypher.blockchain.eth.decorators import ( @@ -1545,16 +1546,11 @@ class BlockchainPolicyAuthor(NucypherTokenActor): payload = {**blockchain_payload, **policy_end_time} return payload - def recruit(self, quantity: int, **options) -> List[str]: + def get_stakers_reservoir(self, **options) -> StakersReservoir: """ - Uses sampling logic to gather stakers from the blockchain and - caches the resulting node ethereum addresses. - - :param quantity: Number of ursulas to sample from the blockchain. - + Get a sampler object containing the currently registered stakers. """ - staker_addresses = self.staking_agent.sample(quantity=quantity, **options) - return staker_addresses + return self.staking_agent.get_stakers_reservoir(**options) def create_policy(self, *args, **kwargs): """ diff --git a/nucypher/blockchain/eth/agents.py b/nucypher/blockchain/eth/agents.py index 1b28b1209..ee709d9ff 100644 --- a/nucypher/blockchain/eth/agents.py +++ b/nucypher/blockchain/eth/agents.py @@ -727,57 +727,32 @@ class StakingEscrowAgent(EthereumContractAgent): staker_address: ChecksumAddress = self.contract.functions.stakers(index).call() yield staker_address - @contract_api(CONTRACT_CALL) def sample(self, quantity: int, duration: int, pagination_size: Optional[int] = None ) -> List[ChecksumAddress]: - """ - Select n random Stakers, according to their stake distribution. - The returned addresses are shuffled. + reservoir = self.get_stakers_reservoir(duration=duration, pagination_size=pagination_size) + return reservoir.draw(quantity) - See full diagram here: https://github.com/nucypher/kms-whitepaper/blob/master/pdf/miners-ruler.pdf + @contract_api(CONTRACT_CALL) + def get_stakers_reservoir(self, + duration: int, + without: Iterable[ChecksumAddress] = [], + pagination_size: Optional[int] = None) -> 'StakersReservoir': + n_tokens, stakers_map = self.get_all_active_stakers(periods=duration, + pagination_size=pagination_size) - This method implements the Probability Proportional to Size (PPS) sampling algorithm. - In few words, the algorithm places in a line all active stakes that have locked tokens for - at least `duration` periods; a staker is selected if an input point is within its stake. - For example: + self.log.debug(f"Got {len(stakers_map)} stakers with {n_tokens} total tokens") - ``` - Stakes: |----- S0 ----|--------- S1 ---------|-- S2 --|---- S3 ---|-S4-|----- S5 -----| - Points: ....R0.......................R1..................R2...............R3........... - ``` + for address in without: + del stakers_map[address] - In this case, Stakers 0, 1, 3 and 5 will be selected. - - Only stakers which made a commitment to the current period (in the previous period) are used. - """ - - n_tokens, stakers_map = self.get_all_active_stakers(periods=duration, pagination_size=pagination_size) - - # TODO: can be implemented as an iterator if necessary, where the user can - # sample addresses one by one without calling get_all_active_stakers() repeatedly. - - if n_tokens == 0: + # TODO: or is it enough to just make sure the number of remaining stakers is non-zero? + if sum(stakers_map.values()) == 0: raise self.NotEnoughStakers('There are no locked tokens for duration {}.'.format(duration)) - if quantity > len(stakers_map): - raise self.NotEnoughStakers(f'Cannot sample {quantity} out of {len(stakers)} total stakers') - - addresses = list(stakers_map.keys()) - tokens = list(stakers_map.values()) - sampler = WeightedSampler(addresses, tokens) - - system_random = random.SystemRandom() - sampled_addresses = sampler.sample_no_replacement(system_random, quantity) - - # Randomize the output to avoid the largest stakers always being the first in the list - system_random.shuffle(sampled_addresses) # inplace - - self.log.debug(f"Sampled {len(addresses)} stakers: {list(sampled_addresses)}") - - return sampled_addresses + return StakersReservoir(stakers_map) @contract_api(CONTRACT_CALL) def get_completed_work(self, bidder_address: ChecksumAddress) -> Work: @@ -1657,7 +1632,10 @@ class WeightedSampler: (does not mutate the object and only applies to the current invocation of the method). """ - if quantity > len(self.totals): + if quantity == 0: + return [] + + if quantity > len(self): raise ValueError("Cannot sample more than the total amount of elements without replacement") totals = self.totals.copy() @@ -1676,3 +1654,27 @@ class WeightedSampler: totals[j] -= weight return samples + + def __len__(self): + return len(self.totals) + + +class StakersReservoir: + + def __init__(self, stakers_map): + addresses = list(stakers_map.keys()) + tokens = list(stakers_map.values()) + self._sampler = WeightedSampler(addresses, tokens) + self._rng = random.SystemRandom() + + def __len__(self): + return len(self._sampler) + + def draw(self, quantity): + if quantity > len(self): + raise StakingEscrowAgent.NotEnoughStakers(f'Cannot sample {quantity} out of {len(self)} total stakers') + + return self._sampler.sample_no_replacement(self._rng, quantity) + + def draw_at_most(self, quantity): + return self.draw(min(quantity, len(self))) diff --git a/nucypher/network/nodes.py b/nucypher/network/nodes.py index 42e60b7c6..aa3bf066c 100644 --- a/nucypher/network/nodes.py +++ b/nucypher/network/nodes.py @@ -34,7 +34,7 @@ from eth_utils import to_checksum_address from requests.exceptions import SSLError from twisted.internet import defer, reactor, task from twisted.internet.threads import deferToThread -from typing import Set, Tuple, Union +from typing import Set, Tuple, Union, Iterable from umbral.signing import Signature import nucypher @@ -603,9 +603,10 @@ class Learner: # TODO: Allow the user to set eagerness? 1712 self.learn_from_teacher_node(eager=False) - def learn_about_specific_nodes(self, addresses: Set): - self._node_ids_to_learn_about_immediately.update(addresses) # hmmmm - self.learn_about_nodes_now() + def learn_about_specific_nodes(self, addresses: Iterable): + if len(addresses) > 0: + self._node_ids_to_learn_about_immediately.update(addresses) # hmmmm + self.learn_about_nodes_now() # TODO: Dehydrate these next two methods. NRN diff --git a/nucypher/policy/policies.py b/nucypher/policy/policies.py index 363ff35a0..ffe72729c 100644 --- a/nucypher/policy/policies.py +++ b/nucypher/policy/policies.py @@ -15,6 +15,7 @@ You should have received a copy of the GNU Affero General Public License along with nucypher. If not, see . """ +import time import random from collections import OrderedDict, deque @@ -22,7 +23,7 @@ import maya from abc import ABC, abstractmethod from bytestring_splitter import BytestringSplitter, VariableLengthBytestring from constant_sorrow.constants import NOT_SIGNED, UNKNOWN_KFRAG -from typing import Generator, List, Set +from typing import Generator, List, Set, Optional from umbral.keys import UmbralPublicKey from umbral.kfrags import KFrag @@ -381,7 +382,7 @@ class Policy(ABC): def make_arrangements(self, network_middleware: RestMiddleware, - handpicked_ursulas: Set[Ursula] = None, + handpicked_ursulas: Optional[Set[Ursula]] = None, *args, **kwargs, ) -> None: @@ -408,11 +409,12 @@ class Policy(ABC): raise NotImplementedError @abstractmethod - def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]: + def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula]) -> Set[Ursula]: raise NotImplementedError - def sample(self, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]: - selected_ursulas = set(handpicked_ursulas) if handpicked_ursulas else set() + def sample(self, handpicked_ursulas: Optional[Set[Ursula]] = None) -> Set[Ursula]: + handpicked_ursulas = handpicked_ursulas if handpicked_ursulas else set() + selected_ursulas = set(handpicked_ursulas) # Calculate the target sample quantity target_sample_quantity = self.n - len(selected_ursulas) @@ -475,11 +477,11 @@ class FederatedPolicy(Policy): "Pass them here as handpicked_ursulas.".format(self.n) raise self.MoreKFragsThanArrangements(error) # TODO: NotEnoughUrsulas where in the exception tree is this? - def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]: + def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula]) -> Set[Ursula]: known_nodes = self.alice.known_nodes if handpicked_ursulas: # Prevent re-sampling of handpicked ursulas. - known_nodes = set(known_nodes) - set(handpicked_ursulas) + known_nodes = set(known_nodes) - handpicked_ursulas sampled_ursulas = set(random.sample(k=quantity, population=list(known_nodes))) return sampled_ursulas @@ -572,57 +574,69 @@ class BlockchainPolicy(Policy): params = dict(rate=rate, value=value) return params - def __find_ursulas(self, - ether_addresses: List[str], - target_quantity: int, - timeout: int = 10) -> set: # TODO #843: Make timeout configurable + def sample_essential(self, + quantity: int, + handpicked_ursulas: Set[Ursula], + learner_timeout: int = 1, + timeout: int = 10) -> Set[Ursula]: - start_time = maya.now() # marker for timeout calculation + selected_ursulas = set(handpicked_ursulas) + quantity_remaining = quantity - found_ursulas, unknown_addresses = set(), deque() - while len(found_ursulas) < target_quantity: # until there are enough Ursulas + # Need to sample some stakers - delta = maya.now() - start_time # check for a timeout - if delta.total_seconds() >= timeout: - missing_nodes = ', '.join(a for a in unknown_addresses) - raise RuntimeError("Timed out after {} seconds; Cannot find {}.".format(timeout, missing_nodes)) - - # Select an ether_address: Prefer the selection pool, then unknowns queue - if ether_addresses: - ether_address = ether_addresses.pop() - else: - ether_address = unknown_addresses.popleft() - - try: - # Check if this is a known node. - selected_ursula = self.alice.known_nodes[ether_address] - - except KeyError: - # Unknown Node - self.alice.learn_about_specific_nodes({ether_address}) # enter address in learning loop - unknown_addresses.append(ether_address) - continue - - else: - # Known Node - found_ursulas.add(selected_ursula) # We already knew, or just learned about this ursula - - return found_ursulas - - def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]: - # TODO: Prevent re-sampling of handpicked ursulas. - selected_addresses = set() - try: - sampled_addresses = self.alice.recruit(quantity=quantity, - duration=self.duration_periods) - except StakingEscrowAgent.NotEnoughStakers as e: - error = f"Cannot create policy with {quantity} arrangements: {e}" + handpicked_addresses = [ursula.checksum_address for ursula in handpicked_ursulas] + reservoir = self.alice.get_stakers_reservoir(duration=self.duration_periods, + without=handpicked_addresses) + if len(reservoir) < quantity_remaining: + error = f"Cannot create policy with {quantity} arrangements" raise self.NotEnoughBlockchainUrsulas(error) - # Capture the selection and search the network for those Ursulas - selected_addresses.update(sampled_addresses) - found_ursulas = self.__find_ursulas(sampled_addresses, quantity) - return found_ursulas + to_check = reservoir.draw(quantity_remaining) + + # Sample stakers in a loop and feed them to the learner to check + # until we have enough in selected_ursulas`. + + start_time = maya.now() + new_to_check = to_check + + while True: + + # Check if the sampled addresses are already known. + # If we're lucky, we won't have to wait for the learner iteration to finish. + known = list(filter(lambda x: x in self.alice.known_nodes, to_check)) + to_check = list(filter(lambda x: x not in self.alice.known_nodes, to_check)) + + known = known[:min(len(known), quantity_remaining)] # we only need so many + selected_ursulas.update([self.alice.known_nodes[address] for address in known]) + quantity_remaining -= len(known) + + if quantity_remaining == 0: + break + else: + new_to_check = reservoir.draw_at_most(quantity_remaining) + to_check.extend(new_to_check) + + # Feed newly sampled stakers to the learner + self.alice.learn_about_specific_nodes(new_to_check) + + # TODO: would be nice to wait for the learner to finish an iteration here, + # because if it hasn't, we really have nothing to do. + time.sleep(learner_timeout) + + delta = maya.now() - start_time + if delta.total_seconds() >= timeout: + still_checking = ', '.join(to_check) + raise RuntimeError(f"Timed out after {timeout} seconds; " + f"need {quantity} more, still checking {still_checking}.") + + found_ursulas = list(selected_ursulas) + + # Randomize the output to avoid the largest stakers always being the first in the list + system_random = random.SystemRandom() + system_random.shuffle(found_ursulas) # inplace + + return set(found_ursulas) def publish_to_blockchain(self) -> dict: diff --git a/tests/acceptance/blockchain/agents/test_sampling_distribution.py b/tests/acceptance/blockchain/agents/test_sampling_distribution.py index d86dbc969..68100b84b 100644 --- a/tests/acceptance/blockchain/agents/test_sampling_distribution.py +++ b/tests/acceptance/blockchain/agents/test_sampling_distribution.py @@ -117,7 +117,8 @@ def test_sampling_distribution(testerchain, token, deploy_contract, token_econom sampled, failed = 0, 0 while sampled < SAMPLES: try: - addresses = set(staking_agent.sample(quantity=quantity, duration=1)) + reservoir = staking_agent.get_stakers_reservoir(duration=1) + addresses = set(reservoir.draw(quantity)) addresses.discard(NULL_ADDRESS) except staking_agent.NotEnoughStakers: failed += 1 diff --git a/tests/acceptance/blockchain/agents/test_staking_escrow_agent.py b/tests/acceptance/blockchain/agents/test_staking_escrow_agent.py index 31f11f9ae..d835d7c26 100644 --- a/tests/acceptance/blockchain/agents/test_staking_escrow_agent.py +++ b/tests/acceptance/blockchain/agents/test_staking_escrow_agent.py @@ -292,13 +292,13 @@ def test_lock_restaking(agency, testerchain, test_registry): staking_agent = ContractAgency.get_agent(StakingEscrowAgent, registry=test_registry) current_period = staking_agent.get_current_period() terminal_period = current_period + 2 - + assert staking_agent.is_restaking(staker_account) assert not staking_agent.is_restaking_locked(staker_account) receipt = staking_agent.lock_restaking(staker_account, release_period=terminal_period) assert receipt['status'] == 1, "Transaction Rejected" assert staking_agent.is_restaking_locked(staker_account) - + testerchain.time_travel(periods=2) # Wait for re-staking lock to be released. assert not staking_agent.is_restaking_locked(staker_account) From c76b8f59b5da2df7bf79c91d07bc3a9c48b5c42a Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 29 Jul 2020 16:39:11 -0700 Subject: [PATCH 3/6] Implement RFCs --- nucypher/blockchain/eth/agents.py | 21 +++++++++++++-------- nucypher/policy/policies.py | 2 +- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/nucypher/blockchain/eth/agents.py b/nucypher/blockchain/eth/agents.py index ee709d9ff..349fc7155 100644 --- a/nucypher/blockchain/eth/agents.py +++ b/nucypher/blockchain/eth/agents.py @@ -740,17 +740,22 @@ class StakingEscrowAgent(EthereumContractAgent): duration: int, without: Iterable[ChecksumAddress] = [], pagination_size: Optional[int] = None) -> 'StakersReservoir': + n_tokens, stakers_map = self.get_all_active_stakers(periods=duration, - pagination_size=pagination_size) - - self.log.debug(f"Got {len(stakers_map)} stakers with {n_tokens} total tokens") + pagination_size=pagination_size) + filtered_out = 0 for address in without: - del stakers_map[address] + if address in stakers_map: + n_tokens -= stakers_map[address] + del stakers_map[address] + filtered_out += 1 - # TODO: or is it enough to just make sure the number of remaining stakers is non-zero? - if sum(stakers_map.values()) == 0: - raise self.NotEnoughStakers('There are no locked tokens for duration {}.'.format(duration)) + self.log.debug(f"Got {len(stakers_map)} stakers with {n_tokens} total tokens " + f"({filtered_out} filtered out)") + + if n_tokens == 0: + raise self.NotEnoughStakers(f'There are no locked tokens for duration {duration}.') return StakersReservoir(stakers_map) @@ -1614,7 +1619,7 @@ class ContractAgency: class WeightedSampler: """ - Samples random elements with probabilities proportioinal to given weights. + Samples random elements with probabilities proportional to given weights. """ def __init__(self, elements: Iterable, weights: Iterable[int]): diff --git a/nucypher/policy/policies.py b/nucypher/policy/policies.py index ffe72729c..3137f2f0e 100644 --- a/nucypher/policy/policies.py +++ b/nucypher/policy/policies.py @@ -628,7 +628,7 @@ class BlockchainPolicy(Policy): if delta.total_seconds() >= timeout: still_checking = ', '.join(to_check) raise RuntimeError(f"Timed out after {timeout} seconds; " - f"need {quantity} more, still checking {still_checking}.") + f"need {quantity_remaining} more, still checking {still_checking}.") found_ursulas = list(selected_ursulas) From f8c562ac572dbe539e10e49d1eef63859317c605 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 7 Aug 2020 18:37:58 -0700 Subject: [PATCH 4/6] Implement RFCs, part 2 --- nucypher/blockchain/eth/agents.py | 19 ++++--------------- nucypher/policy/policies.py | 16 ++++++++-------- .../agents/test_policy_manager_agent.py | 4 ++-- .../agents/test_sampling_distribution.py | 5 +++-- .../agents/test_staking_escrow_agent.py | 8 ++++---- 5 files changed, 21 insertions(+), 31 deletions(-) diff --git a/nucypher/blockchain/eth/agents.py b/nucypher/blockchain/eth/agents.py index 349fc7155..1849589a8 100644 --- a/nucypher/blockchain/eth/agents.py +++ b/nucypher/blockchain/eth/agents.py @@ -727,15 +727,6 @@ class StakingEscrowAgent(EthereumContractAgent): staker_address: ChecksumAddress = self.contract.functions.stakers(index).call() yield staker_address - def sample(self, - quantity: int, - duration: int, - pagination_size: Optional[int] = None - ) -> List[ChecksumAddress]: - reservoir = self.get_stakers_reservoir(duration=duration, pagination_size=pagination_size) - return reservoir.draw(quantity) - - @contract_api(CONTRACT_CALL) def get_stakers_reservoir(self, duration: int, without: Iterable[ChecksumAddress] = [], @@ -1622,8 +1613,8 @@ class WeightedSampler: Samples random elements with probabilities proportional to given weights. """ - def __init__(self, elements: Iterable, weights: Iterable[int]): - assert len(elements) == len(weights) + def __init__(self, weighted_elements: Dict[Any, int]): + elements, weights = zip(*weighted_elements.items()) self.totals = list(accumulate(weights)) self.elements = elements @@ -1666,10 +1657,8 @@ class WeightedSampler: class StakersReservoir: - def __init__(self, stakers_map): - addresses = list(stakers_map.keys()) - tokens = list(stakers_map.values()) - self._sampler = WeightedSampler(addresses, tokens) + def __init__(self, stakers_map: Dict[ChecksumAddress, int]): + self._sampler = WeightedSampler(stakers_map) self._rng = random.SystemRandom() def __len__(self): diff --git a/nucypher/policy/policies.py b/nucypher/policy/policies.py index 3137f2f0e..238c54202 100644 --- a/nucypher/policy/policies.py +++ b/nucypher/policy/policies.py @@ -413,7 +413,7 @@ class Policy(ABC): raise NotImplementedError def sample(self, handpicked_ursulas: Optional[Set[Ursula]] = None) -> Set[Ursula]: - handpicked_ursulas = handpicked_ursulas if handpicked_ursulas else set() + handpicked_ursulas = handpicked_ursulas or set() selected_ursulas = set(handpicked_ursulas) # Calculate the target sample quantity @@ -578,7 +578,7 @@ class BlockchainPolicy(Policy): quantity: int, handpicked_ursulas: Set[Ursula], learner_timeout: int = 1, - timeout: int = 10) -> Set[Ursula]: + timeout: int = 10) -> Set[Ursula]: # TODO #843: Make timeout configurable selected_ursulas = set(handpicked_ursulas) quantity_remaining = quantity @@ -592,10 +592,10 @@ class BlockchainPolicy(Policy): error = f"Cannot create policy with {quantity} arrangements" raise self.NotEnoughBlockchainUrsulas(error) - to_check = reservoir.draw(quantity_remaining) + to_check = set(reservoir.draw(quantity_remaining)) # Sample stakers in a loop and feed them to the learner to check - # until we have enough in selected_ursulas`. + # until we have enough in `selected_ursulas`. start_time = maya.now() new_to_check = to_check @@ -604,10 +604,10 @@ class BlockchainPolicy(Policy): # Check if the sampled addresses are already known. # If we're lucky, we won't have to wait for the learner iteration to finish. - known = list(filter(lambda x: x in self.alice.known_nodes, to_check)) - to_check = list(filter(lambda x: x not in self.alice.known_nodes, to_check)) + known = {x for x in to_check if x in self.alice.known_nodes} + to_check = to_check - known - known = known[:min(len(known), quantity_remaining)] # we only need so many + known = random.sample(known, min(len(known), quantity_remaining)) # we only need so many selected_ursulas.update([self.alice.known_nodes[address] for address in known]) quantity_remaining -= len(known) @@ -615,7 +615,7 @@ class BlockchainPolicy(Policy): break else: new_to_check = reservoir.draw_at_most(quantity_remaining) - to_check.extend(new_to_check) + to_check.update(new_to_check) # Feed newly sampled stakers to the learner self.alice.learn_about_specific_nodes(new_to_check) diff --git a/tests/acceptance/blockchain/agents/test_policy_manager_agent.py b/tests/acceptance/blockchain/agents/test_policy_manager_agent.py index b5b94909f..7aaa3390a 100644 --- a/tests/acceptance/blockchain/agents/test_policy_manager_agent.py +++ b/tests/acceptance/blockchain/agents/test_policy_manager_agent.py @@ -35,7 +35,7 @@ def policy_meta(testerchain, agency, token_economics, blockchain_ursulas): agent = policy_agent _policy_id = os.urandom(16) - staker_addresses = list(staking_agent.sample(quantity=3, duration=1)) + staker_addresses = list(staking_agent.get_stakers_reservoir(duration=1).draw(3)) number_of_periods = 10 now = testerchain.w3.eth.getBlock(block_identifier='latest').timestamp _txhash = agent.create_policy(policy_id=_policy_id, @@ -56,7 +56,7 @@ def test_create_policy(testerchain, agency, token_economics, mock_transacting_po mock_transacting_power_activation(account=testerchain.alice_account, password=INSECURE_DEVELOPMENT_PASSWORD) policy_id = os.urandom(16) - node_addresses = list(staking_agent.sample(quantity=3, duration=1)) + node_addresses = list(staking_agent.get_stakers_reservoir(duration=1).draw(3)) now = testerchain.w3.eth.getBlock(block_identifier='latest').timestamp receipt = agent.create_policy(policy_id=policy_id, author_address=testerchain.alice_account, diff --git a/tests/acceptance/blockchain/agents/test_sampling_distribution.py b/tests/acceptance/blockchain/agents/test_sampling_distribution.py index 68100b84b..33c6baab9 100644 --- a/tests/acceptance/blockchain/agents/test_sampling_distribution.py +++ b/tests/acceptance/blockchain/agents/test_sampling_distribution.py @@ -156,13 +156,14 @@ def probability_reference_no_replacement(weights, idxs): @pytest.mark.parametrize('sample_size', [1, 2, 3]) def test_weighted_sampler(sample_size): weights = [1, 9, 100, 2, 18, 70] + elements = list(range(len(weights))) rng = random.SystemRandom() counter = Counter() - elements = list(range(len(weights))) + weighted_elements = {element: weight for element, weight in zip(elements, weights)} samples = 100000 - sampler = WeightedSampler(elements, weights) + sampler = WeightedSampler(weighted_elements) for i in range(samples): sample_set = sampler.sample_no_replacement(rng, sample_size) counter.update({tuple(sample_set): 1}) diff --git a/tests/acceptance/blockchain/agents/test_staking_escrow_agent.py b/tests/acceptance/blockchain/agents/test_staking_escrow_agent.py index d835d7c26..c416baf17 100644 --- a/tests/acceptance/blockchain/agents/test_staking_escrow_agent.py +++ b/tests/acceptance/blockchain/agents/test_staking_escrow_agent.py @@ -152,19 +152,19 @@ def test_sample_stakers(agency): stakers_population = staking_agent.get_staker_population() with pytest.raises(StakingEscrowAgent.NotEnoughStakers): - staking_agent.sample(quantity=stakers_population + 1, duration=1) # One more than we have deployed + staking_agent.get_stakers_reservoir(duration=1).draw(stakers_population + 1) # One more than we have deployed - stakers = staking_agent.sample(quantity=3, duration=5) + stakers = staking_agent.get_stakers_reservoir(duration=5).draw(3) assert len(stakers) == 3 # Three... assert len(set(stakers)) == 3 # ...unique addresses # Same but with pagination - stakers = staking_agent.sample(quantity=3, duration=5, pagination_size=1) + stakers = staking_agent.get_stakers_reservoir(duration=5, pagination_size=1).draw(3) assert len(stakers) == 3 assert len(set(stakers)) == 3 light = staking_agent.blockchain.is_light staking_agent.blockchain.is_light = not light - stakers = staking_agent.sample(quantity=3, duration=5) + stakers = staking_agent.get_stakers_reservoir(duration=5).draw(3) assert len(stakers) == 3 assert len(set(stakers)) == 3 staking_agent.blockchain.is_light = light From 32371761e4644809c1be1ec717d0f4b81fe34646 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Tue, 11 Aug 2020 22:27:13 -0700 Subject: [PATCH 5/6] Implement RFCs, part 3 --- nucypher/policy/policies.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nucypher/policy/policies.py b/nucypher/policy/policies.py index 238c54202..81f5ca471 100644 --- a/nucypher/policy/policies.py +++ b/nucypher/policy/policies.py @@ -413,8 +413,7 @@ class Policy(ABC): raise NotImplementedError def sample(self, handpicked_ursulas: Optional[Set[Ursula]] = None) -> Set[Ursula]: - handpicked_ursulas = handpicked_ursulas or set() - selected_ursulas = set(handpicked_ursulas) + selected_ursulas = set(handpicked_ursulas) if handpicked_ursulas else set() # Calculate the target sample quantity target_sample_quantity = self.n - len(selected_ursulas) From 449139e0a14dd2a748fe2218fc5cda2239188eff Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 12 Aug 2020 15:40:17 -0700 Subject: [PATCH 6/6] Fix a logical mistake in sample() --- nucypher/policy/policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nucypher/policy/policies.py b/nucypher/policy/policies.py index 81f5ca471..8681077f8 100644 --- a/nucypher/policy/policies.py +++ b/nucypher/policy/policies.py @@ -419,7 +419,7 @@ class Policy(ABC): target_sample_quantity = self.n - len(selected_ursulas) if target_sample_quantity > 0: sampled_ursulas = self.sample_essential(quantity=target_sample_quantity, - handpicked_ursulas=handpicked_ursulas) + handpicked_ursulas=selected_ursulas) selected_ursulas.update(sampled_ursulas) return selected_ursulas