Add tests for BatchValueFactory, and fix issues.

pull/3126/head
derekpierre 2023-05-09 12:55:37 -04:00
parent 3e498a3751
commit bc8cb48c70
2 changed files with 215 additions and 16 deletions

View File

@ -336,19 +336,23 @@ class BatchValueFactory:
def __init__(
self, values: List[Any], required_successes: int, batch_size: int = None
):
self.values = values
self.batch_start_index = 0
if len(self.values) < required_successes:
raise ValueError(
f"Available values ({len(self.values)} less than required successes {required_successes}"
)
if required_successes < 0:
if not values:
raise ValueError(f"No available values provided")
if required_successes <= 0:
raise ValueError(
f"Invalid number of successes required ({required_successes})"
)
self.required_successes = required_successes
if batch_size and batch_size <= 0:
self.values = values
self.required_successes = required_successes
if len(self.values) < self.required_successes:
raise ValueError(
f"Available values ({len(self.values)} less than required successes {self.required_successes}"
)
self._batch_start_index = 0
if batch_size is not None and batch_size <= 0:
raise ValueError(f"Invalid batch size specified ({batch_size})")
self.batch_size = batch_size if batch_size else required_successes
@ -357,18 +361,17 @@ class BatchValueFactory:
# no more work needed to be done
return None
if self.batch_start_index == len(self.values):
if self._batch_start_index == len(self.values):
# no more values to process
return None
batch_size = min(self.required_successes - successes, self.batch_size)
batch_end_index = self.batch_start_index + batch_size
batch_end_index = self._batch_start_index + self.batch_size
if batch_end_index <= len(self.values):
batch = self.values[self.batch_start_index : batch_end_index]
self.batch_start_index = batch_end_index
batch = self.values[self._batch_start_index : batch_end_index]
self._batch_start_index = batch_end_index
return batch
else:
# return all remaining values
batch = self.values[self.batch_start_index :]
self.batch_start_index = len(self.values)
batch = self.values[self._batch_start_index :]
self._batch_start_index = len(self.values)
return batch

View File

@ -0,0 +1,196 @@
import pytest
from nucypher.utilities.concurrency import BatchValueFactory
NUM_VALUES = 20
@pytest.fixture(scope="module")
def values():
values = []
for i in range(0, NUM_VALUES):
values.append(i)
return values
def test_batch_value_factory_invalid_values(values):
with pytest.raises(ValueError):
BatchValueFactory(values=[], required_successes=0)
with pytest.raises(ValueError):
BatchValueFactory(values=[], required_successes=1)
with pytest.raises(ValueError):
BatchValueFactory(values=[1, 2, 3, 4], required_successes=5)
with pytest.raises(ValueError):
BatchValueFactory(values=[1, 2, 3, 4], required_successes=2, batch_size=0)
def test_batch_value_factory_all_successes_no_specified_batching(values):
target_successes = NUM_VALUES
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
# number of successes returned since no batching provided
value_list = value_factory(successes=0)
assert len(value_list) == target_successes, "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=NUM_VALUES) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"
def test_batch_value_factory_no_specified_batching_no_more_values_after_target_successes(
values,
):
target_successes = 1
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
for i in range(0, NUM_VALUES // 3):
value_list = value_factory(successes=0)
assert (
len(value_list) == target_successes
), "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
for i in range(NUM_VALUES // 3, NUM_VALUES):
value_list = value_factory(successes=target_successes)
assert (
not value_list
), "there are more values but no more is needed since target successes attained"
def test_batch_value_factory_no_batching_no_success_multiple_calls(values):
target_successes = 4
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
for i in range(0, NUM_VALUES // target_successes):
value_list = value_factory(successes=0)
assert (
len(value_list) == target_successes
), "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
# list all done but get list again
value_list = value_factory(successes=target_successes) # successes achieved
assert not value_list, "successes achieved"
# list all done but get list again
value_list = value_factory(
successes=1
) # not enough successes but list is now empty
assert not value_list, "successes not achieved, but no more values available"
def test_batch_value_factory_no_batching_no_success_multiple_calls_non_divisible_successes(
values,
):
target_successes = 6
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
# should be able to get 4 lists
for i in range(0, NUM_VALUES // target_successes):
value_list = value_factory(successes=0)
assert (
len(value_list) == target_successes
), "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
# last request
value_list = value_factory(successes=0)
assert len(value_list) == NUM_VALUES % target_successes, "remaining list returned"
# get list again
value_list = value_factory(successes=target_successes) # successes achieved
assert not value_list, "successes achieved"
# get list again
value_list = value_factory(
successes=target_successes - 1
) # not enough successes but list is now empty
assert not value_list, "successes not achieved, but no more values available"
def test_batch_value_factory_batching_individual(values):
target_successes = NUM_VALUES
batch_size = 1
value_factory = BatchValueFactory(
values=values, required_successes=target_successes, batch_size=batch_size
)
# number of successes returned since no batching provided
for i in range(0, NUM_VALUES // batch_size):
value_list = value_factory(successes=0)
assert len(value_list) == batch_size, "list returned is based on batch size"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=NUM_VALUES) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"
def test_batch_value_factory_batching_divisible(values):
target_successes = NUM_VALUES
batch_size = 5
value_factory = BatchValueFactory(
values=values, required_successes=target_successes, batch_size=batch_size
)
# number of successes returned since no batching provided (3x here)
for i in range(0, NUM_VALUES // batch_size):
value_list = value_factory(successes=target_successes - 1)
assert len(value_list) == batch_size, "list returned is based on batch size"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=NUM_VALUES) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"
def test_batch_value_factory_batching_non_divisible(values):
target_successes = NUM_VALUES
batch_size = 7
value_factory = BatchValueFactory(
values=values, required_successes=target_successes, batch_size=batch_size
)
# number of successes returned since no batching provided
for i in range(0, NUM_VALUES // batch_size):
value_list = value_factory(successes=0)
assert len(value_list) == batch_size, "list returned is based on batch size"
assert len(values) == NUM_VALUES, "values remained unchanged"
# one more
value_list = value_factory(successes=0)
assert len(value_list) == NUM_VALUES % batch_size, "remainder of list returned"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=target_successes) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"