diff --git a/nucypher/blockchain/eth/agents.py b/nucypher/blockchain/eth/agents.py
index 94ff6cd83..8c29cd08f 100644
--- a/nucypher/blockchain/eth/agents.py
+++ b/nucypher/blockchain/eth/agents.py
@@ -16,6 +16,7 @@ along with nucypher. If not, see .
"""
+import math
import random
from typing import Generator, List, Tuple, Union
@@ -356,24 +357,20 @@ class StakingEscrowAgent(EthereumContractAgent):
"""
stakers_population = self.get_staker_population()
- if quantity > stakers_population:
- raise self.NotEnoughStakers(f'There are {stakers_population} published stakers, need a total of {quantity}.')
+ n_select = math.ceil(quantity * additional_ursulas) # Select more Ursulas
+ if n_select > stakers_population:
+ raise self.NotEnoughStakers(f'There are {stakers_population} active stakers, need at least {n_select}.')
system_random = random.SystemRandom()
- n_select = round(quantity*additional_ursulas) # Select more Ursulas
n_tokens = self.contract.functions.getAllLockedTokens(duration).call()
-
if n_tokens == 0:
raise self.NotEnoughStakers('There are no locked tokens for duration {}.'.format(duration))
for _ in range(attempts):
- points = [0] + sorted(system_random.randrange(n_tokens) for _ in range(n_select))
+ points = sorted(system_random.randrange(n_tokens) for _ in range(n_select))
+ self.log.debug(f"Sampling {n_select} stakers with random points: {points}")
- deltas = []
- for next_point, previous_point in zip(points[1:], points[:-1]):
- deltas.append(next_point - previous_point)
-
- addresses = set(self.contract.functions.sample(deltas, duration).call())
+ addresses = set(self.contract.functions.sample(points, duration).call())
addresses.discard(str(BlockchainInterface.NULL_ADDRESS))
if len(addresses) >= quantity:
diff --git a/nucypher/blockchain/eth/sol/source/contracts/StakingEscrow.sol b/nucypher/blockchain/eth/sol/source/contracts/StakingEscrow.sol
index 02039667a..7a518990f 100644
--- a/nucypher/blockchain/eth/sol/source/contracts/StakingEscrow.sol
+++ b/nucypher/blockchain/eth/sol/source/contracts/StakingEscrow.sol
@@ -839,16 +839,20 @@ contract StakingEscrow is Issuer {
/**
* @notice Get active stakers based on input points
- * @param _points Array of absolute values
+ * @param _points Array of absolute values. Must be sorted in ascending order.
* @param _periods Amount of periods for locked tokens calculation
*
- * @dev Sampling iterates over an array of stakers and input points.
- * Each iteration checks if the current point is contained within the current staker stake.
- * If the point is greater than or equal to the current sum of stakes,
- * this staker is skipped and the sum is increased by the value of next staker's stake.
- * If a point is less than the current sum of stakes, then the current staker is appended to the resulting array.
- * Secondly, the sum of stakes is decreased by a point;
- * The next iteration will check the next point for the difference.
+ * @dev This method implements the Probability Proportional to Size (PPS) sampling algorithm,
+ * but with the random input data provided in the _points array.
+ * In few words, the algorithm places in a line all active stakes that have locked tokens for
+ * at least _periods periods; a staker is selected if an input point is within its stake.
+ * For example:
+ *
+ * Stakes: |----- S0 ----|--------- S1 ---------|-- S2 --|---- S3 ---|-S4-|----- S5 -----|
+ * Points: ....R0.......................R1..................R2...............R3...........
+ *
+ * In this case, Stakers 0, 1, 3 and 5 will be selected.
+ *
* Only stakers which confirmed the current period (in the previous period) are used.
* If the number of points is more than the number of active stakers with suitable stakes,
* the last values in the resulting array will be zeros addresses.
@@ -862,31 +866,30 @@ contract StakingEscrow is Issuer {
uint16 nextPeriod = currentPeriod.add16(_periods);
result = new address[](_points.length);
+ uint256 previousPoint = 0;
uint256 pointIndex = 0;
uint256 sumOfLockedTokens = 0;
uint256 stakerIndex = 0;
- bool addMoreTokens = true;
while (stakerIndex < stakers.length && pointIndex < _points.length) {
address currentStaker = stakers[stakerIndex];
StakerInfo storage info = stakerInfo[currentStaker];
- uint256 point = _points[pointIndex];
if (info.confirmedPeriod1 != currentPeriod &&
info.confirmedPeriod2 != currentPeriod) {
stakerIndex += 1;
- addMoreTokens = true;
continue;
}
- if (addMoreTokens) {
- sumOfLockedTokens = sumOfLockedTokens.add(getLockedTokens(info, currentPeriod, nextPeriod));
- }
- if (sumOfLockedTokens > point) {
+ uint256 stakerTokens = getLockedTokens(info, currentPeriod, nextPeriod);
+ uint256 nextSumValue = sumOfLockedTokens.add(stakerTokens);
+
+ uint256 point = _points[pointIndex];
+ require(point >= previousPoint); // _points must be a sorted array
+ if (sumOfLockedTokens <= point && point < nextSumValue) {
result[pointIndex] = currentStaker;
- sumOfLockedTokens -= point;
pointIndex += 1;
- addMoreTokens = false;
+ previousPoint = point;
} else {
stakerIndex += 1;
- addMoreTokens = true;
+ sumOfLockedTokens = nextSumValue;
}
}
}
diff --git a/nucypher/policy/policies.py b/nucypher/policy/policies.py
index db1aefc8d..956f497d4 100644
--- a/nucypher/policy/policies.py
+++ b/nucypher/policy/policies.py
@@ -377,23 +377,16 @@ class Policy(ABC):
raise NotImplementedError
def sample(self, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]:
- if not handpicked_ursulas:
- handpicked_ursulas = set()
- else:
- handpicked_ursulas = set(handpicked_ursulas)
+ selected_ursulas = set(handpicked_ursulas) if handpicked_ursulas else set()
# Calculate the target sample quantity
- ADDITIONAL_URSULAS = self.selection_buffer
- target_sample_quantity = self.n - len(handpicked_ursulas)
- actual_sample_quantity = math.ceil(target_sample_quantity * ADDITIONAL_URSULAS)
+ target_sample_quantity = self.n - len(selected_ursulas)
+ if target_sample_quantity > 0:
+ sampled_ursulas = self.sample_essential(quantity=target_sample_quantity,
+ handpicked_ursulas=handpicked_ursulas)
+ selected_ursulas.update(sampled_ursulas)
- if actual_sample_quantity > 0:
- selected_ursulas = self.sample_essential(quantity=actual_sample_quantity,
- handpicked_ursulas=handpicked_ursulas)
- handpicked_ursulas.update(selected_ursulas)
-
- final_ursulas = handpicked_ursulas
- return final_ursulas
+ return selected_ursulas
def _consider_arrangements(self,
network_middleware: RestMiddleware,
@@ -597,8 +590,9 @@ class BlockchainPolicy(Policy):
# TODO: Prevent re-sampling of handpicked ursulas.
selected_addresses = set()
try:
- # Sample by reading from the Blockchain
- sampled_addresses = self.alice.recruit(quantity=quantity, duration=self.duration_periods)
+ sampled_addresses = self.alice.recruit(quantity=quantity,
+ duration=self.duration_periods,
+ additional_ursulas=self.selection_buffer)
except StakingEscrowAgent.NotEnoughStakers as e:
error = f"Cannot create policy with {quantity} arrangements: {e}"
raise self.NotEnoughBlockchainUrsulas(error)
diff --git a/tests/blockchain/eth/contracts/main/staking_escrow/test_tracking.py b/tests/blockchain/eth/contracts/main/staking_escrow/test_sampling.py
similarity index 81%
rename from tests/blockchain/eth/contracts/main/staking_escrow/test_tracking.py
rename to tests/blockchain/eth/contracts/main/staking_escrow/test_sampling.py
index 0ea00c1dc..bacc0fee8 100644
--- a/tests/blockchain/eth/contracts/main/staking_escrow/test_tracking.py
+++ b/tests/blockchain/eth/contracts/main/staking_escrow/test_sampling.py
@@ -18,14 +18,13 @@ along with nucypher. If not, see .
import pytest
from eth_tester.exceptions import TransactionFailed
-from web3.contract import Contract
+from nucypher.blockchain.eth.interfaces import BlockchainInterface
@pytest.mark.slow
def test_sampling(testerchain, token, escrow_contract):
escrow = escrow_contract(5 * 10 ** 8)
- NULL_ADDR = '0x' + '0' * 40
- creator = testerchain.client.accounts[0]
+ creator = testerchain.etherbase_account
# Give Escrow tokens for reward and initialize contract
tx = token.functions.transfer(escrow.address, 10 ** 9).transact({'from': creator})
@@ -43,16 +42,18 @@ def test_sampling(testerchain, token, escrow_contract):
testerchain.wait_for_receipt(tx)
amount = amount // 2
- # Cant't use sample without points or with zero periods value
- with pytest.raises((TransactionFailed, ValueError)):
+ # Can't use sample without points, with zero periods value, or not sorted in ascending order.
+ with pytest.raises(TransactionFailed):
escrow.functions.sample([], 1).call()
- with pytest.raises((TransactionFailed, ValueError)):
+ with pytest.raises(TransactionFailed):
escrow.functions.sample([1], 0).call()
+ with pytest.raises(TransactionFailed):
+ escrow.functions.sample([3, 2, 1], 0).call()
# No stakers yet
addresses = escrow.functions.sample([1], 1).call()
assert 1 == len(addresses)
- assert NULL_ADDR == addresses[0]
+ assert BlockchainInterface.NULL_ADDRESS == addresses[0]
all_locked_tokens = 0
# All stakers lock tokens for different lock periods
@@ -74,7 +75,7 @@ def test_sampling(testerchain, token, escrow_contract):
# So sampling in current period is useless
addresses = escrow.functions.sample([1], 1).call()
assert 1 == len(addresses)
- assert NULL_ADDR == addresses[0]
+ assert BlockchainInterface.NULL_ADDRESS == addresses[0]
# Wait next period and check all locked tokens
testerchain.time_travel(hours=1)
@@ -89,35 +90,36 @@ def test_sampling(testerchain, token, escrow_contract):
testerchain.wait_for_receipt(tx)
# Sample one staker by value less than first staker's stake
- addresses = escrow.functions.sample([all_locked_tokens // 3], 1).call()
+ addresses = escrow.functions.sample([largest_locked - 1], 1).call()
assert 1 == len(addresses)
assert stakers[0] == addresses[0]
# Sample two stakers by values that are equal to first and second stakes
# In the result must be second and third stakers because of strict condition in the sampling
- # sumOfLockedTokens > point
- addresses = escrow.functions.sample([largest_locked, largest_locked // 2], 1).call()
+ addresses = escrow.functions.sample([largest_locked, largest_locked + largest_locked // 2], 1).call()
assert 2 == len(addresses)
assert stakers[1] == addresses[0]
assert stakers[2] == addresses[1]
+ # Sample two stakers by values that within first and second stakes
+ # In the result must be second and third stakers because of strict condition in the sampling
+ # sumOfLockedTokens > point
+ addresses = escrow.functions.sample([largest_locked - 1, largest_locked + 1], 1).call()
+ assert 2 == len(addresses)
+ assert stakers[0] == addresses[0]
+ assert stakers[1] == addresses[1]
+
# Sample staker by the max duration of the longest stake
# The result is the staker who has the longest stake
addresses = escrow.functions.sample([1], len(stakers)).call()
assert 1 == len(addresses)
assert stakers[-1] == addresses[0]
+
# Sample staker using the duration more than the longest stake
# The result is empty
addresses = escrow.functions.sample([1], len(stakers) + 1).call()
assert 1 == len(addresses)
- assert NULL_ADDR == addresses[0]
-
- # Sample by values that more than all locked tokens
- # Only one staker will be in the result
- addresses = escrow.functions.sample([largest_locked, largest_locked], 1).call()
- assert 2 == len(addresses)
- assert stakers[1] == addresses[0]
- assert NULL_ADDR == addresses[1]
+ assert BlockchainInterface.NULL_ADDRESS == addresses[0]
# Sample stakers by different durations and minimum value
# Each result is the first appropriate stake by length
@@ -127,9 +129,11 @@ def test_sampling(testerchain, token, escrow_contract):
assert stakers[index] == addresses[0]
# Sample all stakers by values as stake minus one
- # The result must contain all stakers because of condition sumOfLockedTokens > point
+ # The result must contain all stakers
points = [escrow.functions.getLockedTokens(staker).call() for staker in stakers]
points[0] = points[0] - 1
+ for i in range(1, len(points)):
+ points[i] += points[i-1]
addresses = escrow.functions.sample(points, 1).call()
assert stakers == addresses
diff --git a/tests/blockchain/eth/contracts/main/staking_escrow/test_sampling_distribution.py b/tests/blockchain/eth/contracts/main/staking_escrow/test_sampling_distribution.py
new file mode 100644
index 000000000..995488c7e
--- /dev/null
+++ b/tests/blockchain/eth/contracts/main/staking_escrow/test_sampling_distribution.py
@@ -0,0 +1,111 @@
+"""
+This file is part of nucypher.
+
+nucypher is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+nucypher is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public License
+along with nucypher. If not, see .
+"""
+import pytest
+
+from nucypher.blockchain.eth.interfaces import BlockchainInterface
+from nucypher.blockchain.eth.constants import STAKING_ESCROW_CONTRACT_NAME
+
+
+# TODO: #1288 - Consider moving this test out from regular CI workflow to a scheduled workflow (e.g., nightly)
+@pytest.mark.slow
+def test_sampling_distribution(testerchain, token, deploy_contract):
+
+ #
+ # SETUP
+ #
+
+ max_allowed_locked_tokens = 5 * 10 ** 8
+ _staking_coefficient = 2 * 10 ** 7
+ contract, _ = deploy_contract(
+ contract_name=STAKING_ESCROW_CONTRACT_NAME,
+ _token=token.address,
+ _hoursPerPeriod=1,
+ _miningCoefficient=4 * _staking_coefficient,
+ _lockedPeriodsCoefficient=4,
+ _rewardedPeriods=4,
+ _minLockedPeriods=2,
+ _minAllowableLockedTokens=100,
+ _maxAllowableLockedTokens=max_allowed_locked_tokens,
+ _minWorkerPeriods=1
+ )
+
+ policy_manager, _ = deploy_contract(
+ 'PolicyManagerForStakingEscrowMock', token.address, contract.address
+ )
+ tx = contract.functions.setPolicyManager(policy_manager.address).transact()
+ testerchain.wait_for_receipt(tx)
+
+ # Travel to the start of the next period to prevent problems with unexpected overflow first period
+ testerchain.time_travel(hours=1)
+
+ escrow = contract
+ creator = testerchain.etherbase_account
+
+ # Give Escrow tokens for reward and initialize contract
+ tx = token.functions.transfer(escrow.address, 10 ** 9).transact({'from': creator})
+ testerchain.wait_for_receipt(tx)
+ tx = escrow.functions.initialize().transact({'from': creator})
+ testerchain.wait_for_receipt(tx)
+
+ stakers = testerchain.stakers_accounts
+ amount = token.functions.balanceOf(creator).call() // len(stakers)
+
+ # Airdrop
+ for staker in stakers:
+ tx = token.functions.transfer(staker, amount).transact({'from': creator})
+ testerchain.wait_for_receipt(tx)
+
+ all_locked_tokens = len(stakers) * amount
+ for staker in stakers:
+ balance = token.functions.balanceOf(staker).call()
+ tx = token.functions.approve(escrow.address, balance).transact({'from': staker})
+ testerchain.wait_for_receipt(tx)
+ tx = escrow.functions.deposit(balance, 10).transact({'from': staker})
+ testerchain.wait_for_receipt(tx)
+ tx = escrow.functions.setWorker(staker).transact({'from': staker})
+ testerchain.wait_for_receipt(tx)
+ tx = escrow.functions.confirmActivity().transact({'from': staker})
+ testerchain.wait_for_receipt(tx)
+
+ # Wait next period and check all locked tokens
+ testerchain.time_travel(hours=1)
+
+ #
+ # Test sampling distribution
+ #
+
+ ERROR_TOLERANCE = 0.05 # With this tolerance, all sampling ratios should between 5% and 15% (expected is 10%)
+ SAMPLES = 100
+ quantity = 3
+ import random
+ from collections import Counter
+
+ counter = Counter()
+ for i in range(SAMPLES):
+ points = sorted(random.SystemRandom().randrange(all_locked_tokens) for _ in range(quantity))
+ addresses = set(escrow.functions.sample(points, 1).call())
+ addresses.discard(BlockchainInterface.NULL_ADDRESS)
+ counter.update(addresses)
+
+ total_times = sum(counter.values())
+
+ expected = amount / all_locked_tokens
+ for staker in stakers:
+ times = counter[staker]
+ sampled_ratio = times / total_times
+ abs_error = abs(expected - sampled_ratio)
+ assert abs_error < ERROR_TOLERANCE
diff --git a/tests/characters/test_ursula_prepares_to_act_as_mining_node.py b/tests/characters/test_ursula_prepares_to_act_as_mining_node.py
index cac5a22f2..4b6316345 100644
--- a/tests/characters/test_ursula_prepares_to_act_as_mining_node.py
+++ b/tests/characters/test_ursula_prepares_to_act_as_mining_node.py
@@ -164,12 +164,11 @@ def test_blockchain_ursulas_reencrypt(blockchain_ursulas, blockchain_alice, bloc
label = b'bbo'
- # TODO: Investigate issues with wiggle room and additional ursulas during sampling. See also #1061 and #1090
- # 1 <= N <= 4 : OK, although for N=4 it can fail with very small probability (<1%)
- # M = N = 5: Fails with prob. ~66% --> Cannot create policy with 5 arrangements: Selection failed after 5 attempts
- # N == 6 : NotEnoughBlockchainUrsulas: Cannot create policy with 6 arrangements: Selection failed after 5 attempts
- # N >= 7 : NotEnoughBlockchainUrsulas: Cannot create policy with 7 arrangements: Cannot create policy with 7 arrangements: 10 stakers are available, need 11 (for wiggle room)
- m = n = 3
+ # TODO: Make sample selection buffer configurable - #1061
+ # Currently, it only supports N<=6, since for N=7, it tries to sample 11 ursulas due to wiggle room,
+ # and blockchain_ursulas only contains 10.
+ # For N >= 7 : NotEnoughBlockchainUrsulas: Cannot create policy with 7 arrangements: There are 10 active stakers, need at least 11.
+ m = n = 6
expiration = maya.now() + datetime.timedelta(days=5)
_policy = blockchain_alice.grant(bob=blockchain_bob,