mirror of https://github.com/nucypher/nucypher.git
Make WeightedSampler behave correctly in case of several consecutive draws
parent
13a6ab8375
commit
e81f285517
|
@ -1763,6 +1763,7 @@ class WeightedSampler:
|
|||
elements, weights = zip(*weighted_elements.items())
|
||||
self.totals = list(accumulate(weights))
|
||||
self.elements = elements
|
||||
self.__length = len(self.totals)
|
||||
|
||||
def sample_no_replacement(self, rng, quantity: int) -> list:
|
||||
"""
|
||||
|
@ -1780,25 +1781,26 @@ class WeightedSampler:
|
|||
if quantity > len(self):
|
||||
raise ValueError("Cannot sample more than the total amount of elements without replacement")
|
||||
|
||||
totals = self.totals.copy()
|
||||
samples = []
|
||||
|
||||
for i in range(quantity):
|
||||
position = rng.randint(0, totals[-1] - 1)
|
||||
idx = bisect_right(totals, position)
|
||||
position = rng.randint(0, self.totals[-1] - 1)
|
||||
idx = bisect_right(self.totals, position)
|
||||
samples.append(self.elements[idx])
|
||||
|
||||
# Adjust the totals so that they correspond
|
||||
# to the weight of the element `idx` being set to 0.
|
||||
prev_total = totals[idx - 1] if idx > 0 else 0
|
||||
weight = totals[idx] - prev_total
|
||||
for j in range(idx, len(totals)):
|
||||
totals[j] -= weight
|
||||
prev_total = self.totals[idx - 1] if idx > 0 else 0
|
||||
weight = self.totals[idx] - prev_total
|
||||
for j in range(idx, len(self.totals)):
|
||||
self.totals[j] -= weight
|
||||
|
||||
self.__length -= quantity
|
||||
|
||||
return samples
|
||||
|
||||
def __len__(self):
|
||||
return len(self.totals)
|
||||
return self.__length
|
||||
|
||||
|
||||
class StakersReservoir:
|
||||
|
|
|
@ -166,8 +166,8 @@ def test_weighted_sampler(sample_size):
|
|||
weighted_elements = {element: weight for element, weight in zip(elements, weights)}
|
||||
|
||||
samples = 100000
|
||||
sampler = WeightedSampler(weighted_elements)
|
||||
for i in range(samples):
|
||||
sampler = WeightedSampler(weighted_elements)
|
||||
sample_set = sampler.sample_no_replacement(rng, sample_size)
|
||||
counter.update({tuple(sample_set): 1})
|
||||
|
||||
|
|
Loading…
Reference in New Issue