Make WeightedSampler behave correctly in case of several consecutive draws

pull/2482/head
Bogdan Opanchuk 2020-12-17 21:16:03 -08:00
parent 13a6ab8375
commit e81f285517
2 changed files with 11 additions and 9 deletions

View File

@ -1763,6 +1763,7 @@ class WeightedSampler:
elements, weights = zip(*weighted_elements.items())
self.totals = list(accumulate(weights))
self.elements = elements
self.__length = len(self.totals)
def sample_no_replacement(self, rng, quantity: int) -> list:
"""
@ -1780,25 +1781,26 @@ class WeightedSampler:
if quantity > len(self):
raise ValueError("Cannot sample more than the total amount of elements without replacement")
totals = self.totals.copy()
samples = []
for i in range(quantity):
position = rng.randint(0, totals[-1] - 1)
idx = bisect_right(totals, position)
position = rng.randint(0, self.totals[-1] - 1)
idx = bisect_right(self.totals, position)
samples.append(self.elements[idx])
# Adjust the totals so that they correspond
# to the weight of the element `idx` being set to 0.
prev_total = totals[idx - 1] if idx > 0 else 0
weight = totals[idx] - prev_total
for j in range(idx, len(totals)):
totals[j] -= weight
prev_total = self.totals[idx - 1] if idx > 0 else 0
weight = self.totals[idx] - prev_total
for j in range(idx, len(self.totals)):
self.totals[j] -= weight
self.__length -= quantity
return samples
def __len__(self):
# NOTE(review): this page is a diff rendering whose +/- markers and indentation
# were stripped, so two alternative bodies appear back to back. Presumably the
# first `return` is the removed line and the second is its replacement --
# confirm against the actual commit before relying on this.
# Variant 1: the number of entries in the cumulative-weight list `self.totals`.
return len(self.totals)
# Variant 2: the counter that __init__ seeds with len(self.totals) and that
# sample_no_replacement decrements by `quantity`.
return self.__length
class StakersReservoir:

View File

@ -166,8 +166,8 @@ def test_weighted_sampler(sample_size):
weighted_elements = {element: weight for element, weight in zip(elements, weights)}
samples = 100000
sampler = WeightedSampler(weighted_elements)
for i in range(samples):
sampler = WeightedSampler(weighted_elements)
sample_set = sampler.sample_no_replacement(rng, sample_size)
counter.update({tuple(sample_set): 1})