From e81f2855178ac33215ae127bfd074b3723dda73a Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Thu, 17 Dec 2020 21:16:03 -0800 Subject: [PATCH] Make WeightedSampler behave correctly in case of several consecutive draws --- nucypher/blockchain/eth/agents.py | 18 ++++++++++-------- .../agents/test_sampling_distribution.py | 2 +- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/nucypher/blockchain/eth/agents.py b/nucypher/blockchain/eth/agents.py index e1ebddbae..89faa371b 100644 --- a/nucypher/blockchain/eth/agents.py +++ b/nucypher/blockchain/eth/agents.py @@ -1763,6 +1763,7 @@ class WeightedSampler: elements, weights = zip(*weighted_elements.items()) self.totals = list(accumulate(weights)) self.elements = elements + self.__length = len(self.totals) def sample_no_replacement(self, rng, quantity: int) -> list: """ @@ -1780,25 +1781,26 @@ class WeightedSampler: if quantity > len(self): raise ValueError("Cannot sample more than the total amount of elements without replacement") - totals = self.totals.copy() samples = [] for i in range(quantity): - position = rng.randint(0, totals[-1] - 1) - idx = bisect_right(totals, position) + position = rng.randint(0, self.totals[-1] - 1) + idx = bisect_right(self.totals, position) samples.append(self.elements[idx]) # Adjust the totals so that they correspond # to the weight of the element `idx` being set to 0. - prev_total = totals[idx - 1] if idx > 0 else 0 - weight = totals[idx] - prev_total - for j in range(idx, len(totals)): - totals[j] -= weight + prev_total = self.totals[idx - 1] if idx > 0 else 0 + weight = self.totals[idx] - prev_total + for j in range(idx, len(self.totals)): + self.totals[j] -= weight + + self.__length -= quantity return samples def __len__(self): - return len(self.totals) + return self.__length class StakersReservoir: diff --git a/tests/acceptance/blockchain/agents/test_sampling_distribution.py b/tests/acceptance/blockchain/agents/test_sampling_distribution.py index 7bfac2f43..64d73b2e2 100644 --- a/tests/acceptance/blockchain/agents/test_sampling_distribution.py +++ b/tests/acceptance/blockchain/agents/test_sampling_distribution.py @@ -166,8 +166,8 @@ def test_weighted_sampler(sample_size): weighted_elements = {element: weight for element, weight in zip(elements, weights)} samples = 100000 - sampler = WeightedSampler(weighted_elements) for i in range(samples): + sampler = WeightedSampler(weighted_elements) sample_set = sampler.sample_no_replacement(rng, sample_size) counter.update({tuple(sample_set): 1})