Generalize NodeEngagementMutex

pull/2482/head
Bogdan Opanchuk 2021-01-02 23:37:17 -08:00
parent a2b99daa1d
commit bbc4390f68
10 changed files with 712 additions and 168 deletions

View File

@ -102,7 +102,7 @@ policy = ALICE.grant(BOB,
expiration=policy_end_datetime)
assert policy.public_key == policy_pubkey
policy.publishing_mutex.block_until_complete()
policy.treasure_map_publisher.block_until_complete()
# Alice puts her public key somewhere for Bob to find later...
alices_pubkey_bytes_saved_for_posterity = bytes(ALICE.stamp)

View File

@ -141,7 +141,7 @@ policy = alicia.grant(bob=doctor_strange,
m=m,
n=n,
expiration=policy_end_datetime)
policy.publishing_mutex.block_until_complete()
policy.treasure_map_publisher.block_until_complete()
print("Done!")
# For the demo, we need a way to share with Bob some additional info

View File

@ -123,7 +123,7 @@ class AliceInterface(CharacterPublicInterface):
expiration=expiration,
discover_on_this_thread=True)
new_policy.publishing_mutex.block_until_success_is_reasonably_likely()
new_policy.treasure_map_publisher.block_until_success_is_reasonably_likely()
response_data = {'treasure_map': new_policy.treasure_map,
'policy_encrypting_key': new_policy.public_key,

View File

@ -351,7 +351,7 @@ class Alice(Character, BlockchainPolicyAuthor):
self.add_active_policy(enacted_policy)
if publish_treasure_map and block_until_success_is_reasonably_likely:
enacted_policy.publishing_mutex.block_until_success_is_reasonably_likely()
enacted_policy.treasure_map_publisher.block_until_success_is_reasonably_likely()
return enacted_policy
def get_policy_encrypting_key_from_label(self, label: bytes) -> UmbralPublicKey:

View File

@ -47,6 +47,7 @@ from nucypher.crypto.powers import DecryptingPower, SigningPower, TransactingPow
from nucypher.crypto.utils import construct_policy_id
from nucypher.network.exceptions import NodeSeemsToBeDown
from nucypher.network.middleware import RestMiddleware
from nucypher.utilities.concurrency import WorkerPool, AllAtOnceFactory
from nucypher.utilities.logging import Logger
@ -90,151 +91,48 @@ class Arrangement:
return f"Arrangement(client_key={self.alice_verifying_key})"
class NodeEngagementMutex:
"""
TODO: Does this belong on middleware?
class TreasureMapPublisher:
TODO: There are a couple of ways this can break. If one of the jobs hangs, the whole thing will hang. Also,
if fewer jobs complete successfully than percent_to_complete_before_release, the partial queue will never
release.
TODO: Make registry per... I guess Policy? It's weird to be able to accidentally enact again.
"""
log = Logger("Policy")
log = Logger('TreasureMapPublisher')
def __init__(self,
callable_to_engage, # TODO: typing.Protocol
worker,
nodes,
network_middleware,
percent_to_complete_before_release=5,
note=None,
threadpool_size=120,
timeout=20,
*args,
**kwargs):
self.f = callable_to_engage
self.nodes = nodes
self.network_middleware = network_middleware
self.args = args
self.kwargs = kwargs
timeout=20):
self.completed = {}
self.failed = {}
self._total = len(nodes)
self._block_until_this_many_are_complete = math.ceil(len(nodes) * percent_to_complete_before_release / 100)
self._worker_pool = WorkerPool(worker=worker,
value_factory=AllAtOnceFactory(nodes),
target_successes=self._block_until_this_many_are_complete,
timeout=timeout,
stagger_timeout=0,
threadpool_size=threadpool_size)
self._started = False
self._finished = False
self.timeout = timeout
self.percent_to_complete_before_release = percent_to_complete_before_release
self._partial_queue = Queue()
self._completion_queue = Queue()
self._block_until_this_many_are_complete = math.ceil(
len(nodes) * self.percent_to_complete_before_release / 100)
self.nodes_contacted_during_partial_block = False
self.when_complete = Deferred() # TODO: Allow cancelling via KB Interrupt or some other way?
if note is None:
self._repr = f"{callable_to_engage} to {len(nodes)} nodes"
else:
self._repr = f"{note}: {callable_to_engage} to {len(nodes)} nodes"
self._threadpool = ThreadPool(minthreads=threadpool_size, maxthreads=threadpool_size, name=self._repr)
self.log.info(f"NEM spinning up {self._threadpool}")
self._threadpool.callInThread(self._bail_on_timeout)
def __repr__(self):
return self._repr
def _bail_on_timeout(self):
while True:
if self.when_complete.called:
return
duration = datetime.datetime.now() - self._started
if duration.seconds >= self.timeout:
try:
self._threadpool.stop()
except AlreadyQuit:
raise RuntimeError("Is there a race condition here? If this line is being hit, it's a bug.")
raise RuntimeError(f"Timed out. Nodes completed: {self.completed}")
time.sleep(.5)
def block_until_success_is_reasonably_likely(self):
"""
https://www.youtube.com/watch?v=OkSLswPSq2o
"""
if len(self.completed) < self._block_until_this_many_are_complete:
try:
completed_for_reasonable_likelihood_of_success = self._partial_queue.get(timeout=self.timeout) # TODO: Shorter timeout here?
except Empty:
raise RuntimeError(f"Timed out. Nodes completed: {self.completed}")
self.log.debug(f"{len(self.completed)} nodes were contacted while blocking for a little while.")
return completed_for_reasonable_likelihood_of_success
else:
return self.completed
def block_until_complete(self):
if self.total_disposed() < len(self.nodes):
try:
_ = self._completion_queue.get(timeout=self.timeout) # Interesting opportunity to pass some data, like the list of contacted nodes above.
except Empty:
raise RuntimeError(f"Timed out. Nodes completed: {self.completed}")
if not reactor.running and not self._threadpool.joined:
# If the reactor isn't running, the user *must* call this, because this is where we stop.
self._threadpool.stop()
def _handle_success(self, response, node):
if response.status_code == 201:
self.completed[node] = response
else:
assert False # TODO: What happens if this is a 300 or 400 level response? (A 500 response will propagate as an error and be handled in the errback chain.)
if self.nodes_contacted_during_partial_block:
self._consider_finalizing()
else:
if len(self.completed) >= self._block_until_this_many_are_complete:
contacted = tuple(self.completed.keys())
self.nodes_contacted_during_partial_block = contacted
self.log.debug(f"Blocked for a little while, completed {contacted} nodes")
self._partial_queue.put(contacted)
return response
def _handle_error(self, failure, node):
self.failed[node] = failure # TODO: Add a failfast mode?
self._consider_finalizing()
self.log.warn(f"{node} failed: {failure}")
def total_disposed(self):
return len(self.completed) + len(self.failed)
def _consider_finalizing(self):
if not self._finished:
if self.total_disposed() == len(self.nodes):
# TODO: Consider whether this can possibly hang.
self._finished = True
if reactor.running:
reactor.callInThread(self._threadpool.stop)
self._completion_queue.put(self.completed)
self.when_complete.callback(self.completed)
self.log.info(f"{self} finished.")
else:
raise RuntimeError("Already finished.")
def _engage_node(self, node):
maybe_coro = self.f(node, network_middleware=self.network_middleware, *self.args, **self.kwargs)
d = ensureDeferred(maybe_coro)
d.addCallback(self._handle_success, node)
d.addErrback(self._handle_error, node)
return d
@property
def completed(self):
# TODO: lock dict before copying?
return self._worker_pool.get_successes()
def start(self):
if self._started:
raise RuntimeError("Already started.")
self._started = datetime.datetime.now()
self.log.info(f"NEM Starting {self._threadpool}")
for node in self.nodes:
self._threadpool.callInThread(self._engage_node, node)
self._threadpool.start()
self.log.info(f"TreasureMapPublisher starting")
self._worker_pool.start()
if reactor.running:
reactor.callInThread(self.block_until_complete)
def block_until_success_is_reasonably_likely(self):
# Note: `OutOfValues`/`TimedOut` may be raised here, which means we didn't even get to
# `percent_to_complete_before_release` successes. For now just letting it fire.
self._worker_pool.block_until_target_successes()
completed = self.completed
self.log.debug(f"The minimal amount of nodes ({len(completed)}) was contacted "
"while blocking for treasure map publication.")
return completed
def block_until_complete(self):
self._worker_pool.join()
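For orientation, a minimal sketch of how the renamed class is driven; the `upload_map` worker and `nodes` list here are hypothetical stand-ins (the real worker is built in `Policy._make_publisher` below):

    nodes = [...]  # hypothetical: metadata of the target Ursula nodes

    def upload_map(node):
        ...  # hypothetical: PUT the treasure map bytes on `node`; raise on failure

    publisher = TreasureMapPublisher(worker=upload_map, nodes=nodes)
    publisher.start()  # non-blocking: hands all nodes to the internal WorkerPool
    contacted = publisher.block_until_success_is_reasonably_likely()  # unblocks at ~5% of nodes by default
    publisher.block_until_complete()  # joins the pool; waits for the remaining uploads to finish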
class MergedReservoir:
@ -534,23 +432,34 @@ class Policy(ABC):
return treasure_map
def _make_publishing_mutex(self,
treasure_map: 'TreasureMap',
network_middleware: RestMiddleware,
) -> NodeEngagementMutex:
async def put_treasure_map_on_node(node, network_middleware):
response = network_middleware.put_treasure_map_on_node(node=node,
map_payload=bytes(treasure_map))
return response
def _make_publisher(self,
treasure_map: 'TreasureMap',
network_middleware: RestMiddleware,
) -> TreasureMapPublisher:
# TODO (#2516): remove hardcoding of 8 nodes
self.alice.block_until_number_of_known_nodes_is(8, timeout=2, learn_on_this_thread=True)
target_nodes = self.bob.matching_nodes_among(self.alice.known_nodes)
treasure_map_bytes = bytes(treasure_map) # prevent the closure from holding the reference
return NodeEngagementMutex(callable_to_engage=put_treasure_map_on_node,
nodes=target_nodes,
network_middleware=network_middleware)
def put_treasure_map_on_node(node):
try:
response = network_middleware.put_treasure_map_on_node(node=node,
map_payload=treasure_map_bytes)
except Exception as e:
self.log.warn(f"Putting treasure map on {node} failed: {e}")
raise
if response.status_code == 201:
return response
else:
message = f"Putting treasure map on {node} failed with response status: {response.status}"
self.log.warn(message)
# TODO: What happens if this is a 300 or 400 level response?
raise Exception(message)
return TreasureMapPublisher(worker=put_treasure_map_on_node,
nodes=target_nodes)
def enact(self,
network_middleware: RestMiddleware,
@ -572,8 +481,8 @@ class Policy(ABC):
treasure_map = self._make_treasure_map(network_middleware=network_middleware,
arrangements=arrangements)
publishing_mutex = self._make_publishing_mutex(treasure_map=treasure_map,
network_middleware=network_middleware)
treasure_map_publisher = self._make_publisher(treasure_map=treasure_map,
network_middleware=network_middleware)
revocation_kit = RevocationKit(treasure_map, self.alice.stamp)
enacted_policy = EnactedPolicy(self._id,
@ -581,7 +490,7 @@ class Policy(ABC):
self.label,
self.public_key,
treasure_map,
publishing_mutex,
treasure_map_publisher,
revocation_kit,
self.alice.stamp)
@ -730,7 +639,7 @@ class BlockchainPolicy(Policy):
def _enact_arrangements(self,
network_middleware,
arrangements,
publish_treasure_map=True) -> NodeEngagementMutex:
publish_treasure_map=True) -> TreasureMapPublisher:
transaction = self._publish_to_blockchain(list(arrangements))
return super()._enact_arrangements(network_middleware=network_middleware,
arrangements=arrangements,
@ -756,7 +665,7 @@ class EnactedPolicy:
label: bytes,
public_key: UmbralPublicKey,
treasure_map: 'TreasureMap',
publishing_mutex: NodeEngagementMutex,
treasure_map_publisher: TreasureMapPublisher,
revocation_kit: RevocationKit,
alice_verifying_key: UmbralPublicKey,
):
@ -766,10 +675,10 @@ class EnactedPolicy:
self.label = label
self.public_key = public_key
self.treasure_map = treasure_map
self.publishing_mutex = publishing_mutex
self.treasure_map_publisher = treasure_map_publisher
self.revocation_kit = revocation_kit
self.n = len(self.treasure_map.destinations)
self.alice_verifying_key = alice_verifying_key
def publish_treasure_map(self):
self.publishing_mutex.start()
self.treasure_map_publisher.start()

View File

@ -0,0 +1,304 @@
"""
This file is part of nucypher.
nucypher is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
nucypher is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""
import time
from queue import Queue, Empty
from threading import Thread, Event, Lock, Timer, get_ident
from typing import Callable, List, Any, Optional, Dict
from constant_sorrow.constants import PRODUCER_STOPPED, TIMEOUT_TRIGGERED
from twisted._threads import AlreadyQuit
from twisted.python.threadpool import ThreadPool
class Success:
def __init__(self, value, result):
self.value = value
self.result = result
class Failure:
def __init__(self, value, exception):
self.value = value
self.exception = exception
class Cancelled(Exception):
pass
class SetOnce:
"""
A convenience wrapper for a value that can be set once (which can be waited on),
and cannot be overwritten (unless cleared).
"""
def __init__(self):
self._lock = Lock()
self._set_event = Event()
self._value = None
def set(self, value):
with self._lock:
if not self._set_event.is_set():
self._value = value
self._set_event.set()
def is_set(self):
return self._set_event.is_set()
def get_and_clear(self):
with self._lock:
value = self._value
self._value = None
self._set_event.clear()
return value
def get(self):
self._set_event.wait()
return self._value
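A quick illustration of the SetOnce semantics (a sketch, not part of the diff):

    cell = SetOnce()
    cell.set(1)  # the first write wins
    cell.set(2)  # silently ignored: a value is already set
    assert cell.get() == 1            # get() blocks until some value is set
    assert cell.get_and_clear() == 1  # returns the value and resets the cell
    cell.set(3)  # allowed again after clearing
    assert cell.get() == 3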
class WorkerPool:
"""
A generalized class that can start multiple workers in a thread pool with values
drawn from the given value factory object,
and wait for their completion and a given number of successes
(a worker returning something without throwing an exception).
"""
class TimedOut(Exception):
"Raised if waiting for the target number of successes timed out."
class OutOfValues(Exception):
"Raised if the value factory is out of values, but the target number was not reached."
def __init__(self,
worker: Callable[[Any], Any],
value_factory: Callable[[int], Optional[List[Any]]],
target_successes,
timeout: float,
stagger_timeout: float = 0,
threadpool_size: int = None):
# TODO: make stagger_timeout a part of the value factory?
self._worker = worker
self._value_factory = value_factory
self._timeout = timeout
self._stagger_timeout = stagger_timeout
self._target_successes = target_successes
thread_pool_kwargs = {}
if threadpool_size is not None:
thread_pool_kwargs['minthreads'] = threadpool_size
thread_pool_kwargs['maxthreads'] = threadpool_size
self._threadpool = ThreadPool(**thread_pool_kwargs)
# These three tasks must be run in separate threads
# to avoid being blocked by workers in the thread pool.
self._bail_on_timeout_thread = Thread(target=self._bail_on_timeout)
self._produce_values_thread = Thread(target=self._produce_values)
self._process_results_thread = Thread(target=self._process_results)
self._successes = {}
self._failures = {}
self._started_tasks = 0
self._finished_tasks = 0
self._cancel_event = Event()
self._result_queue = Queue()
self._target_value = SetOnce()
self._unexpected_error = SetOnce()
self._results_lock = Lock()
self._stopped = False
def start(self):
# TODO: check if already started?
self._threadpool.start()
self._produce_values_thread.start()
self._process_results_thread.start()
self._bail_on_timeout_thread.start()
def cancel(self):
"""
Cancels the tasks enqueued in the thread pool and stops the producer thread.
"""
self._cancel_event.set()
def join(self):
"""
Waits for all the threads to finish.
Can be called several times.
"""
if self._stopped:
return # or raise AlreadyStopped?
self._produce_values_thread.join()
self._process_results_thread.join()
self._bail_on_timeout_thread.join()
# protect from a possible race
try:
self._threadpool.stop()
except AlreadyQuit:
pass
self._stopped = True
if self._unexpected_error.is_set():
e = self._unexpected_error.get()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
def _sleep(self, timeout):
"""
Sleeps for a given timeout, can be interrupted by a cancellation event.
"""
if self._cancel_event.wait(timeout):
raise Cancelled
def block_until_target_successes(self) -> Dict:
"""
Blocks until the target number of successes is reached.
Returns a dictionary of values matched to results.
Can be called several times.
"""
if self._unexpected_error.is_set():
# So that we don't raise it again when join() is called
e = self._unexpected_error.get_and_clear()
raise RuntimeError(f"Unexpected error in the producer thread: {e}")
result = self._target_value.get()
if result == TIMEOUT_TRIGGERED:
raise self.TimedOut()
elif result == PRODUCER_STOPPED:
raise self.OutOfValues()
return result
def get_failures(self) -> Dict:
"""
Get the current failures, as a dictionary of values to thrown exceptions.
"""
with self._results_lock:
return dict(self._failures)
def get_successes(self) -> Dict:
"""
Get the current successes, as a dictionary of values to worker return values.
"""
with self._results_lock:
return dict(self._successes)
def _bail_on_timeout(self):
"""
A service thread that cancels the pool on timeout.
"""
if not self._cancel_event.wait(timeout=self._timeout):
self._target_value.set(TIMEOUT_TRIGGERED)
self._cancel_event.set()
def _worker_wrapper(self, value):
"""
A wrapper that catches exceptions thrown by the worker
and sends the results to the processing thread.
"""
try:
# If we're in the cancelled state, interrupt early
self._sleep(0)
result = self._worker(value)
self._result_queue.put(Success(value, result))
except Cancelled as e:
self._result_queue.put(e)
except BaseException as e:
self._result_queue.put(Failure(value, str(e)))
def _process_results(self):
"""
A service thread that processes worker results
and waits for the target number of successes to be reached.
"""
producer_stopped = False
success_event_reached = False
while True:
result = self._result_queue.get()
if result == PRODUCER_STOPPED:
producer_stopped = True
else:
self._finished_tasks += 1
if isinstance(result, Success):
with self._results_lock:
self._successes[result.value] = result.result
len_successes = len(self._successes)
if not success_event_reached and len_successes == self._target_successes:
# A protection for the case of repeating values.
# Only trigger the target value once.
success_event_reached = True
self._target_value.set(self.get_successes())
if isinstance(result, Failure):
with self._results_lock:
self._failures[result.value] = result.exception
if producer_stopped and self._finished_tasks == self._started_tasks:
self.cancel() # to cancel the timeout thread
self._target_value.set(PRODUCER_STOPPED)
break
def _produce_values(self):
while True:
try:
with self._results_lock:
len_successes = len(self._successes)
batch = self._value_factory(len_successes)
if not batch:
break
self._started_tasks += len(batch)
for value in batch:
# There is a possible race between `callInThread()` and `stop()`,
# but we never execute them at the same time,
# because `join()` checks that the producer thread is stopped.
self._threadpool.callInThread(self._worker_wrapper, value)
self._sleep(self._stagger_timeout)
except Cancelled:
break
except BaseException as e:
self._unexpected_error.set(e)
self.cancel()
break
self._result_queue.put(PRODUCER_STOPPED)
class AllAtOnceFactory:
"""
A simple value factory that returns all its values in a single batch.
"""
def __init__(self, values):
self.values = values
self._produced = False
def __call__(self, _successes):
if self._produced:
return None
else:
self._produced = True
return self.values
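Putting the new primitives together, a minimal end-to-end sketch under the API introduced in this file (the `fetch` worker is a hypothetical stand-in):

    from nucypher.utilities.concurrency import WorkerPool, AllAtOnceFactory

    def fetch(value):
        # Return to record a success; raise to record a failure.
        return value * 2

    pool = WorkerPool(worker=fetch,
                      value_factory=AllAtOnceFactory(list(range(20))),
                      target_successes=5,
                      timeout=10,
                      threadpool_size=10)
    pool.start()
    successes = pool.block_until_target_successes()  # dict of value -> worker result
    pool.cancel()  # drop any work that has not started yet
    pool.join()    # always join, even after cancel()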

View File

@ -168,8 +168,8 @@ def test_treasure_map_cannot_be_duplicated(blockchain_ursulas,
expiration=policy_end_datetime)
matching_ursulas = blockchain_bob.matching_nodes_among(blockchain_ursulas)
completed_ursulas = policy.publishing_mutex.block_until_success_is_reasonably_likely()
# Ursulas in publishing_mutex are not real Ursulas, but just some metadata of remote ones.
completed_ursulas = policy.treasure_map_publisher.block_until_success_is_reasonably_likely()
# Ursulas in `treasure_map_publisher` are not real Ursulas, but just some metadata of remote ones.
# We need a real one to access its datastore.
first_completed_ursula = [ursula for ursula in matching_ursulas if ursula in completed_ursulas][0]

View File

@ -0,0 +1,331 @@
import random
import time
from typing import Iterable, Tuple, List, Callable
import pytest
from nucypher.utilities.concurrency import WorkerPool, AllAtOnceFactory
@pytest.fixture(scope='function')
def join_worker_pool(request):
"""
Makes sure the pool is properly joined at the end of the test,
so that one doesn't have to wrap the whole test in a try-finally block.
"""
pool_to_join = None
def register(pool):
nonlocal pool_to_join
pool_to_join = pool
yield register
pool_to_join.join()
class WorkerRule:
def __init__(self, fails: bool = False, timeout_min: float = 0, timeout_max: float = 0):
self.fails = fails
self.timeout_min = timeout_min
self.timeout_max = timeout_max
class WorkerOutcome:
def __init__(self, fails: bool, timeout: float):
self.fails = fails
self.timeout = timeout
def __call__(self, value):
time.sleep(self.timeout)
if self.fails:
raise Exception(f"Worker for {value} failed")
else:
return value
def generate_workers(rules: Iterable[Tuple[WorkerRule, int]], seed=None):
rng = random.Random(seed)
outcomes = []
for rule, quantity in rules:
for _ in range(quantity):
timeout = rng.uniform(rule.timeout_min, rule.timeout_max)
outcomes.append(WorkerOutcome(rule.fails, timeout))
rng.shuffle(outcomes)
values = list(range(len(outcomes)))
def worker(value):
return outcomes[value](value)
return {value: outcomes[value] for value in values}, worker
def test_wait_for_successes(join_worker_pool):
"""
Checks that `block_until_target_successes()` returns in time and gives all the successes,
if there were enough of them.
"""
outcomes, worker = generate_workers(
[
(WorkerRule(timeout_min=0.5, timeout_max=1.5), 10),
(WorkerRule(fails=True, timeout_min=1, timeout_max=3), 20),
],
seed=123)
factory = AllAtOnceFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=30)
join_worker_pool(pool)
t_start = time.monotonic()
pool.start()
successes = pool.block_until_target_successes()
t_end = time.monotonic()
failures = pool.get_failures()
assert all(outcomes[value].fails for value in failures)
assert len(successes) == 10
# We have more threads in the pool than the workers,
# so all the successful ones should be able to finish right away.
assert t_end - t_start < 2
# Should be able to do it several times
successes = pool.block_until_target_successes()
assert len(successes) == 10
def test_wait_for_successes_out_of_values(join_worker_pool):
"""
Checks that if there weren't enough successful workers, `block_until_target_successes()`
raises an exception when the value factory is exhausted.
"""
outcomes, worker = generate_workers(
[
(WorkerRule(timeout_min=0.5, timeout_max=1.5), 9),
(WorkerRule(fails=True, timeout_min=0.5, timeout_max=1.5), 20),
],
seed=123)
factory = AllAtOnceFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=15)
join_worker_pool(pool)
t_start = time.monotonic()
pool.start()
with pytest.raises(WorkerPool.OutOfValues):
successes = pool.block_until_target_successes()
t_end = time.monotonic()
# We have roughly 2 workers per thread, so it shouldn't take longer than 1.5s (max timeout) * 2
assert t_end - t_start < 4
def test_wait_for_successes_timed_out(join_worker_pool):
"""
Checks that if enough successful workers can't finish before the timeout, we get an exception.
"""
outcomes, worker = generate_workers(
[
(WorkerRule(timeout_min=0, timeout_max=0.5), 9),
(WorkerRule(timeout_min=1.5, timeout_max=2.5), 1),
(WorkerRule(fails=True, timeout_min=1.5, timeout_max=2.5), 20),
],
seed=123)
factory = AllAtOnceFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=1, threadpool_size=30)
join_worker_pool(pool)
t_start = time.monotonic()
pool.start()
with pytest.raises(WorkerPool.TimedOut):
successes = pool.block_until_target_successes()
t_end = time.monotonic()
# Even though timeout is 1, there are long-running workers which we can't interrupt.
assert t_end - t_start < 3
def test_join(join_worker_pool):
"""
Test joining the pool.
"""
outcomes, worker = generate_workers(
[
(WorkerRule(timeout_min=0.5, timeout_max=1.5), 9),
(WorkerRule(fails=True, timeout_min=0.5, timeout_max=1.5), 20),
],
seed=123)
factory = AllAtOnceFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=1, threadpool_size=30)
join_worker_pool(pool)
t_start = time.monotonic()
pool.start()
pool.join()
t_end = time.monotonic()
pool.join() # should work the second time too
# Even though timeout is 1, there are long-running workers which we can't interrupt.
assert t_end - t_start < 3
class BatchFactory:
def __init__(self, values):
self.values = values
self.batch_sizes = []
def __call__(self, successes):
if successes == 10:
return None
batch_size = 10 - successes
if len(self.values) >= batch_size:
batch = self.values[:batch_size]
self.batch_sizes.append(len(batch))
self.values = self.values[batch_size:]
return batch
elif len(self.values) > 0:
batch = self.values
self.batch_sizes.append(len(batch))
self.values = []
return batch
else:
return None
def test_batched_value_generation(join_worker_pool):
"""
Tests a value factory that gives out value batches in portions.
"""
outcomes, worker = generate_workers(
[
(WorkerRule(timeout_min=0.5, timeout_max=1.5), 80),
(WorkerRule(fails=True, timeout_min=0.5, timeout_max=1.5), 80),
],
seed=123)
factory = BatchFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=0.5)
join_worker_pool(pool)
t_start = time.monotonic()
pool.start()
successes = pool.block_until_target_successes()
pool.cancel()
pool.join()
t_end = time.monotonic()
assert len(successes) == 10
# Check that batch sizes in the factory were getting progressively smaller
# as the number of successes grew.
assert all(factory.batch_sizes[i] >= factory.batch_sizes[i+1]
for i in range(len(factory.batch_sizes) - 1))
# Since we canceled the pool, no more workers will be started and we will finish faster
assert t_end - t_start < 4
successes_copy = pool.get_successes()
failures_copy = pool.get_failures()
assert all(value in successes_copy for value in successes)
def test_cancel_waiting_workers(join_worker_pool):
"""
If we have a small pool and many workers, it is possible for workers to be enqueued
one after another in one thread.
We test that if we call `cancel()`, these enqueued workers are cancelled too.
"""
outcomes, worker = generate_workers(
[
(WorkerRule(timeout_min=1, timeout_max=1), 100),
],
seed=123)
factory = AllAtOnceFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10)
join_worker_pool(pool)
t_start = time.monotonic()
pool.start()
pool.block_until_target_successes()
pool.cancel()
pool.join()
t_end = time.monotonic()
# We have 10 threads in the pool and 100 workers that are all enqueued at once at the start.
# If we didn't check for the cancel condition, we would have to wait for 10 seconds.
# We get 10 successes after 1s and cancel the workers,
# but the next workers in each thread have already started, so we have to wait for another 1s.
assert t_end - t_start < 2.5
class BuggyFactory:
def __init__(self, values):
self.values = values
def __call__(self, successes):
if self.values is not None:
values = self.values
self.values = None
return values
else:
raise Exception("Buggy factory")
def test_buggy_factory_raises_on_block(join_worker_pool):
"""
Tests that if there is an exception thrown in the value factory,
it is caught in the first call to `block_until_target_successes()`.
"""
outcomes, worker = generate_workers(
[(WorkerRule(timeout_min=1, timeout_max=1), 100)],
seed=123)
factory = BuggyFactory(list(outcomes))
# Non-zero stagger timeout to make BuggyFactory raise its error only after 1.5s,
# so that we get enough successes for `block_until_target_successes()`.
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=1.5)
join_worker_pool(pool)
pool.start()
time.sleep(2) # wait for the stagger timeout to finish
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.block_until_target_successes()
# Further calls to `block_until_target_successes()` or `join()` don't throw the error.
pool.block_until_target_successes()
pool.cancel()
pool.join()
def test_buggy_factory_raises_on_join(join_worker_pool):
"""
Tests that if there is an exception thrown in the value factory,
it is caught in the first call to `join()`.
"""
outcomes, worker = generate_workers(
[(WorkerRule(timeout_min=1, timeout_max=1), 100)],
seed=123)
factory = BuggyFactory(list(outcomes))
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10)
join_worker_pool(pool)
pool.start()
pool.cancel()
with pytest.raises(RuntimeError, match="Unexpected error in the producer thread"):
pool.join()
pool.join()

View File

@ -241,7 +241,7 @@ def enacted_federated_policy(idle_federated_policy, federated_ursulas):
# REST call happens here, as does population of TreasureMap.
enacted_policy = idle_federated_policy.enact(network_middleware=network_middleware,
handpicked_ursulas=federated_ursulas)
enacted_policy.publishing_mutex.block_until_complete()
enacted_policy.treasure_map_publisher.block_until_complete()
return enacted_policy
@ -278,7 +278,7 @@ def enacted_blockchain_policy(idle_blockchain_policy, blockchain_ursulas):
# REST call happens here, as does population of TreasureMap.
enacted_policy = idle_blockchain_policy.enact(network_middleware=network_middleware,
handpicked_ursulas=list(blockchain_ursulas))
enacted_policy.publishing_mutex.block_until_complete()
enacted_policy.treasure_map_publisher.block_until_complete()
return enacted_policy

View File

@ -140,7 +140,7 @@ def test_alice_verifies_ursula_just_in_time(fleet_of_highperf_mocked_ursulas,
_POLICY_PRESERVER.append(policy)
# @pytest_twisted.inlineCallbacks # TODO: Why does this, in concert with yield policy.publishing_mutex.when_complete, hang?
# @pytest_twisted.inlineCallbacks # TODO: Why does this, in concert with yield policy.treasure_map_publisher.when_complete, hang?
def test_mass_treasure_map_placement(fleet_of_highperf_mocked_ursulas,
highperf_mocked_alice,
highperf_mocked_bob):
@ -186,21 +186,21 @@ def test_mass_treasure_map_placement(fleet_of_highperf_mocked_ursulas,
# defer.setDebugging(True)
# PART II: We block for a little while to ensure that the distribution is going well.
nodes_that_have_the_map_when_we_unblock = policy.publishing_mutex.block_until_success_is_reasonably_likely()
nodes_that_have_the_map_when_we_unblock = policy.treasure_map_publisher.block_until_success_is_reasonably_likely()
little_while_ended_at = datetime.now()
# The number of nodes having the map is at least the minimum to have unblocked.
assert len(nodes_that_have_the_map_when_we_unblock) >= policy.publishing_mutex._block_until_this_many_are_complete
assert len(nodes_that_have_the_map_when_we_unblock) >= policy.treasure_map_publisher._block_until_this_many_are_complete
# The number of nodes having the map is approximately the number you'd expect from full utilization of Alice's publication threadpool.
# TODO: This line fails sometimes because the loop goes too fast.
# assert len(nodes_that_have_the_map_when_we_unblock) == pytest.approx(policy.publishing_mutex._block_until_this_many_are_complete, .2)
# assert len(nodes_that_have_the_map_when_we_unblock) == pytest.approx(policy.treasure_map_publisher._block_until_this_many_are_complete, .2)
# PART III: Having made proper assertions about the publication call and the first block, we allow the rest to
# happen in the background and then ensure that each phase was timely.
# This will block until the distribution is complete.
policy.publishing_mutex.block_until_complete()
policy.treasure_map_publisher.block_until_complete()
complete_distribution_time = datetime.now() - started
partial_blocking_duration = little_while_ended_at - started
# Before Treasure Island (1741), this process took about 3 minutes.
@ -213,7 +213,7 @@ def test_mass_treasure_map_placement(fleet_of_highperf_mocked_ursulas,
# But with debuggers and other processes running on laptops, we give a little leeway.
# We have the same number of successful responses as nodes we expected to have the map.
assert len(policy.publishing_mutex.completed) == len(nodes_we_expect_to_have_the_map)
assert len(policy.treasure_map_publisher.completed) == len(nodes_we_expect_to_have_the_map)
nodes_that_got_the_map = sum(
u._its_down_there_somewhere_let_me_take_another_look is True for u in nodes_we_expect_to_have_the_map)
assert nodes_that_got_the_map == len(nodes_we_expect_to_have_the_map)