mirror of https://github.com/nucypher/nucypher.git
Merge pull request #3562 from derekpierre/latency-tracking
Learner node latency trackingpull/3563/head
commit
a78ac58819
|
@ -0,0 +1 @@
|
|||
Enhance threshold decryption request efficiency by prioritizing nodes in the cohort with lower communication latency.
|
|
@ -387,6 +387,8 @@ class Alice(Character, actors.PolicyAuthor):
|
|||
|
||||
|
||||
class Bob(Character):
|
||||
_TRACK_NODE_LATENCY_STATS = True
|
||||
|
||||
banner = BOB_BANNER
|
||||
_default_dkg_variant = FerveoVariant.Simple
|
||||
_default_crypto_powerups = [SigningPower, DecryptingPower]
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import math
|
||||
from http import HTTPStatus
|
||||
from random import shuffle
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from eth_typing import ChecksumAddress
|
||||
|
@ -80,15 +79,20 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
|
|||
self.log.warn(message)
|
||||
raise self.ThresholdDecryptionRequestFailed(message)
|
||||
|
||||
# TODO: Find a better request order, perhaps based on latency data obtained from discovery loop - #3395
|
||||
requests = list(encrypted_requests)
|
||||
shuffle(requests)
|
||||
ursulas_to_contact = (
|
||||
self._learner.node_latency_collector.order_addresses_by_latency(
|
||||
list(encrypted_requests)
|
||||
)
|
||||
if self._learner.node_latency_collector
|
||||
else list(encrypted_requests)
|
||||
)
|
||||
|
||||
# Discussion about WorkerPool parameters:
|
||||
# "https://github.com/nucypher/nucypher/pull/3393#discussion_r1456307991"
|
||||
worker_pool = WorkerPool(
|
||||
worker=worker,
|
||||
value_factory=self.ThresholdDecryptionRequestFactory(
|
||||
ursulas_to_contact=requests,
|
||||
ursulas_to_contact=ursulas_to_contact,
|
||||
batch_size=math.ceil(threshold * 1.25),
|
||||
threshold=threshold,
|
||||
),
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import suppress
|
||||
|
@ -40,6 +41,7 @@ from nucypher.crypto.signing import InvalidSignature, SignatureStamp
|
|||
from nucypher.network.exceptions import NodeSeemsToBeDown
|
||||
from nucypher.network.middleware import RestMiddleware
|
||||
from nucypher.network.protocols import InterfaceInfo, SuspiciousActivity
|
||||
from nucypher.utilities.latency import NodeLatencyStatsCollector
|
||||
from nucypher.utilities.logging import Logger
|
||||
|
||||
TEACHER_NODES = {
|
||||
|
@ -218,6 +220,8 @@ class Learner:
|
|||
_ROUNDS_WITHOUT_NODES_AFTER_WHICH_TO_SLOW_DOWN = 10
|
||||
__DEFAULT_MIDDLEWARE_CLASS = RestMiddleware
|
||||
|
||||
_TRACK_NODE_LATENCY_STATS = False
|
||||
|
||||
_crashed = (
|
||||
False # moved from Character - why was this in Character and not Learner before
|
||||
)
|
||||
|
@ -261,6 +265,10 @@ class Learner:
|
|||
self.log = Logger("learning-loop") # type: Logger
|
||||
self.domain = domain
|
||||
|
||||
self.node_latency_collector = (
|
||||
NodeLatencyStatsCollector() if self._TRACK_NODE_LATENCY_STATS else None
|
||||
)
|
||||
|
||||
self.learning_deferred = Deferred()
|
||||
default_middleware = self.__DEFAULT_MIDDLEWARE_CLASS(
|
||||
registry=self.registry, eth_endpoint=self.eth_endpoint
|
||||
|
@ -827,11 +835,19 @@ class Learner:
|
|||
return RELAX
|
||||
|
||||
try:
|
||||
response = self.network_middleware.get_nodes_via_rest(
|
||||
node=current_teacher,
|
||||
announce_nodes=announce_nodes,
|
||||
fleet_state_checksum=self.known_nodes.checksum,
|
||||
optional_latency_context_manager = (
|
||||
self.node_latency_collector.get_latency_tracker(
|
||||
current_teacher.checksum_address
|
||||
)
|
||||
if self.node_latency_collector
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
with optional_latency_context_manager:
|
||||
response = self.network_middleware.get_nodes_via_rest(
|
||||
node=current_teacher,
|
||||
announce_nodes=announce_nodes,
|
||||
fleet_state_checksum=self.known_nodes.checksum,
|
||||
)
|
||||
# These except clauses apply to the current_teacher itself, not the learned-about nodes.
|
||||
except NodeSeemsToBeDown as e:
|
||||
unresponsive_nodes.add(current_teacher)
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
import collections
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
from typing import List
|
||||
|
||||
from eth_typing import ChecksumAddress
|
||||
|
||||
|
||||
class NodeLatencyStatsCollector:
|
||||
"""
|
||||
Thread-safe utility that tracks latency statistics related to P2P connections with other nodes.
|
||||
"""
|
||||
|
||||
MAX_MOVING_AVERAGE_WINDOW = 5
|
||||
MAX_LATENCY = float(2**16) # just need a large number for sorting
|
||||
|
||||
class NodeLatencyContextManager:
|
||||
def __init__(
|
||||
self,
|
||||
stats_collector: "NodeLatencyStatsCollector",
|
||||
staker_address: ChecksumAddress,
|
||||
):
|
||||
self._stats_collector = stats_collector
|
||||
self.staker_address = staker_address
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type:
|
||||
# exception occurred - reset stats since connectivity was compromised
|
||||
self._stats_collector.reset_stats(self.staker_address)
|
||||
else:
|
||||
# no exception
|
||||
end_time = time.perf_counter()
|
||||
execution_time = end_time - self.start_time
|
||||
self._stats_collector._update_stats(self.staker_address, execution_time)
|
||||
|
||||
def __init__(self, max_moving_average_window: int = MAX_MOVING_AVERAGE_WINDOW):
|
||||
self._node_stats = defaultdict(
|
||||
lambda: collections.deque([], maxlen=max_moving_average_window)
|
||||
)
|
||||
self._lock = Lock()
|
||||
|
||||
def _update_stats(self, staking_address: ChecksumAddress, latest_time_taken: float):
|
||||
with self._lock:
|
||||
self._node_stats[staking_address].append(latest_time_taken)
|
||||
|
||||
def reset_stats(self, staking_address: ChecksumAddress):
|
||||
with self._lock:
|
||||
self._node_stats[staking_address].clear()
|
||||
|
||||
def get_latency_tracker(
|
||||
self, staker_address: ChecksumAddress
|
||||
) -> NodeLatencyContextManager:
|
||||
return self.NodeLatencyContextManager(
|
||||
stats_collector=self, staker_address=staker_address
|
||||
)
|
||||
|
||||
def get_average_latency_time(self, staking_address: ChecksumAddress) -> float:
|
||||
with self._lock:
|
||||
readings = list(self._node_stats[staking_address])
|
||||
num_readings = len(readings)
|
||||
# just need a large number > 0
|
||||
return (
|
||||
self.MAX_LATENCY if num_readings == 0 else sum(readings) / num_readings
|
||||
)
|
||||
|
||||
def order_addresses_by_latency(
|
||||
self, staking_addresses: List[ChecksumAddress]
|
||||
) -> List[ChecksumAddress]:
|
||||
result = sorted(
|
||||
staking_addresses,
|
||||
key=lambda staking_address: self.get_average_latency_time(staking_address),
|
||||
)
|
||||
return result
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import random
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
import pytest
|
||||
import pytest_twisted
|
||||
|
@ -12,6 +12,7 @@ from nucypher.blockchain.eth.constants import NULL_ADDRESS
|
|||
from nucypher.blockchain.eth.models import Coordinator
|
||||
from nucypher.blockchain.eth.signers.software import InMemorySigner
|
||||
from nucypher.characters.lawful import Enrico, Ursula
|
||||
from nucypher.network.decryption import ThresholdDecryptionClient
|
||||
from nucypher.policy.conditions.evm import ContractCondition, RPCCondition
|
||||
from nucypher.policy.conditions.lingo import (
|
||||
ConditionLingo,
|
||||
|
@ -275,19 +276,88 @@ def test_unauthorized_decryption(
|
|||
yield
|
||||
|
||||
|
||||
def check_decrypt_without_any_cached_values(
|
||||
@pytest_twisted.inlineCallbacks
|
||||
def test_authorized_decryption(
|
||||
mocker,
|
||||
bob,
|
||||
global_allow_list,
|
||||
accounts,
|
||||
coordinator_agent,
|
||||
threshold_message_kit,
|
||||
signer,
|
||||
initiator,
|
||||
ritual_id,
|
||||
cohort,
|
||||
plaintext,
|
||||
):
|
||||
print("==================== DKG DECRYPTION (AUTHORIZED) ====================")
|
||||
# authorize Enrico to encrypt for ritual
|
||||
global_allow_list.authorize(
|
||||
ritual_id,
|
||||
[signer.accounts[0]],
|
||||
sender=accounts[initiator.transacting_power.account],
|
||||
)
|
||||
|
||||
# fake some latency stats
|
||||
latency_stats = {}
|
||||
for ursula in cohort:
|
||||
# reset all stats
|
||||
bob.node_latency_collector.reset_stats(ursula.checksum_address)
|
||||
# add a single data point for each ursula: some time between 0.1 and 4
|
||||
mock_latency = random.uniform(0.1, 4)
|
||||
bob.node_latency_collector._update_stats(ursula.checksum_address, mock_latency)
|
||||
latency_stats[ursula.checksum_address] = mock_latency
|
||||
|
||||
expected_ursula_request_ordering = sorted(
|
||||
list(latency_stats.keys()),
|
||||
key=lambda ursula_checksum: latency_stats[ursula_checksum],
|
||||
)
|
||||
value_factory_spy = mocker.spy(
|
||||
ThresholdDecryptionClient.ThresholdDecryptionRequestFactory, "__init__"
|
||||
)
|
||||
|
||||
# ritual_id, ciphertext, conditions are obtained from the side channel
|
||||
bob.start_learning_loop(now=True)
|
||||
cleartext = yield bob.threshold_decrypt(
|
||||
threshold_message_kit=threshold_message_kit,
|
||||
)
|
||||
assert bytes(cleartext) == plaintext.encode()
|
||||
|
||||
# check that proper ordering of ursulas used for worker pool factory for requests
|
||||
value_factory_spy.assert_called_once_with(
|
||||
ANY,
|
||||
ursulas_to_contact=expected_ursula_request_ordering,
|
||||
batch_size=ANY,
|
||||
threshold=ANY,
|
||||
)
|
||||
|
||||
# check prometheus metric for decryption requests
|
||||
# since all running on the same machine - the value is not per-ursula but rather all
|
||||
num_successes = REGISTRY.get_sample_value(
|
||||
"threshold_decryption_num_successes_total"
|
||||
)
|
||||
|
||||
ritual = coordinator_agent.get_ritual(ritual_id)
|
||||
# at least a threshold of ursulas were successful (concurrency)
|
||||
assert int(num_successes) >= ritual.threshold
|
||||
print("===================== DECRYPTION SUCCESSFUL =====================")
|
||||
yield
|
||||
|
||||
|
||||
@pytest_twisted.inlineCallbacks
|
||||
def test_decrypt_without_any_cached_values(
|
||||
threshold_message_kit, ritual_id, cohort, bob, coordinator_agent, plaintext
|
||||
):
|
||||
print("==================== DKG DECRYPTION NO CACHE ====================")
|
||||
original_validators = cohort[0].dkg_storage.get_validators(ritual_id)
|
||||
|
||||
for ursula in cohort:
|
||||
ursula.dkg_storage.clear(ritual_id)
|
||||
assert ursula.dkg_storage.get_validators(ritual_id) is None
|
||||
assert ursula.dkg_storage.get_active_ritual(ritual_id) is None
|
||||
|
||||
# perform threshold decryption
|
||||
bob.start_learning_loop(now=True)
|
||||
cleartext = bob.threshold_decrypt(
|
||||
cleartext = yield bob.threshold_decrypt(
|
||||
threshold_message_kit=threshold_message_kit,
|
||||
)
|
||||
assert bytes(cleartext) == plaintext.encode()
|
||||
|
@ -308,45 +378,7 @@ def check_decrypt_without_any_cached_values(
|
|||
assert v.public_key == original_validators[v_index].public_key
|
||||
|
||||
assert num_used_ursulas >= ritual.threshold
|
||||
print("===================== DECRYPTION SUCCESSFUL =====================")
|
||||
|
||||
|
||||
@pytest_twisted.inlineCallbacks
|
||||
def test_authorized_decryption(
|
||||
bob,
|
||||
global_allow_list,
|
||||
accounts,
|
||||
coordinator_agent,
|
||||
threshold_message_kit,
|
||||
signer,
|
||||
initiator,
|
||||
ritual_id,
|
||||
plaintext,
|
||||
):
|
||||
print("==================== DKG DECRYPTION (AUTHORIZED) ====================")
|
||||
# authorize Enrico to encrypt for ritual
|
||||
global_allow_list.authorize(
|
||||
ritual_id,
|
||||
[signer.accounts[0]],
|
||||
sender=accounts[initiator.transacting_power.account],
|
||||
)
|
||||
|
||||
# ritual_id, ciphertext, conditions are obtained from the side channel
|
||||
bob.start_learning_loop(now=True)
|
||||
cleartext = yield bob.threshold_decrypt(
|
||||
threshold_message_kit=threshold_message_kit,
|
||||
)
|
||||
assert bytes(cleartext) == plaintext.encode()
|
||||
|
||||
# check prometheus metric for decryption requests
|
||||
# since all running on the same machine - the value is not per-ursula but rather all
|
||||
num_successes = REGISTRY.get_sample_value(
|
||||
"threshold_decryption_num_successes_total"
|
||||
)
|
||||
|
||||
ritual = coordinator_agent.get_ritual(ritual_id)
|
||||
# at least a threshold of ursulas were successful (concurrency)
|
||||
assert int(num_successes) >= ritual.threshold
|
||||
print("===================== DECRYPTION NO CACHE SUCCESSFUL =====================")
|
||||
yield
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,297 @@
|
|||
import math
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from eth_typing import ChecksumAddress
|
||||
|
||||
from nucypher.utilities.latency import NodeLatencyStatsCollector
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def execution_data(get_random_checksum_address):
|
||||
executions = {}
|
||||
|
||||
node_1 = get_random_checksum_address()
|
||||
node_1_exec_times = [11.23, 24.8, 31.5, 40.21]
|
||||
executions[node_1] = node_1_exec_times
|
||||
|
||||
node_2 = get_random_checksum_address()
|
||||
node_2_exec_times = [5.03, 6.78, 7.42, 8.043]
|
||||
executions[node_2] = node_2_exec_times
|
||||
|
||||
node_3 = get_random_checksum_address()
|
||||
node_3_exec_times = [0.44, 4.512, 3.3]
|
||||
executions[node_3] = node_3_exec_times
|
||||
|
||||
node_4 = get_random_checksum_address()
|
||||
node_4_exec_times = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
executions[node_4] = node_4_exec_times
|
||||
|
||||
sorted_order = sorted(
|
||||
list(executions.keys()),
|
||||
key=lambda x: sum(executions.get(x)) / len(executions.get(x)),
|
||||
)
|
||||
assert sorted_order == [
|
||||
node_4,
|
||||
node_3,
|
||||
node_2,
|
||||
node_1,
|
||||
] # test of the test - "that's sooo meta"
|
||||
|
||||
return executions, sorted_order
|
||||
|
||||
|
||||
def test_collector_initialization_no_data_collected(get_random_checksum_address):
|
||||
node_latency_collector = NodeLatencyStatsCollector()
|
||||
|
||||
staker_addresses = [get_random_checksum_address() for _ in range(4)]
|
||||
|
||||
# no data collected so average equals maximum latency
|
||||
for staker_address in staker_addresses:
|
||||
assert (
|
||||
node_latency_collector.get_average_latency_time(staker_address)
|
||||
== NodeLatencyStatsCollector.MAX_LATENCY
|
||||
)
|
||||
|
||||
# no data collected so no change in order
|
||||
assert (
|
||||
node_latency_collector.order_addresses_by_latency(staker_addresses)
|
||||
== staker_addresses
|
||||
)
|
||||
|
||||
|
||||
def test_collector_stats_obtained(execution_data):
|
||||
executions, expected_node_sorted_order = execution_data
|
||||
node_latency_collector = NodeLatencyStatsCollector()
|
||||
|
||||
# update stats for all nodes
|
||||
for node, execution_times in executions.items():
|
||||
for i, exec_time in enumerate(execution_times):
|
||||
node_latency_collector._update_stats(node, exec_time)
|
||||
|
||||
# check ongoing average
|
||||
subset_of_times = execution_times[: (i + 1)]
|
||||
# floating point arithmetic makes an exact check tricky
|
||||
assert math.isclose(
|
||||
node_latency_collector.get_average_latency_time(node),
|
||||
sum(subset_of_times) / len(subset_of_times),
|
||||
)
|
||||
|
||||
# check final average
|
||||
# floating point arithmetic makes an exact check tricky
|
||||
assert math.isclose(
|
||||
node_latency_collector.get_average_latency_time(node),
|
||||
sum(execution_times) / len(execution_times),
|
||||
)
|
||||
|
||||
node_addresses = list(executions.keys())
|
||||
for _ in range(10):
|
||||
# try various random permutations of order
|
||||
random.shuffle(node_addresses)
|
||||
assert (
|
||||
node_latency_collector.order_addresses_by_latency(node_addresses)
|
||||
== expected_node_sorted_order
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"max_window", [NodeLatencyStatsCollector.MAX_MOVING_AVERAGE_WINDOW, 3, 7, 15]
|
||||
)
|
||||
def test_collector_moving_average_window(max_window, get_random_checksum_address):
|
||||
node_collector_stats = NodeLatencyStatsCollector(
|
||||
max_moving_average_window=max_window
|
||||
)
|
||||
node = get_random_checksum_address()
|
||||
exec_times = []
|
||||
|
||||
# <= moving average window
|
||||
for i in range(max_window):
|
||||
value = random.uniform(0, 40)
|
||||
exec_times.append(value)
|
||||
node_collector_stats._update_stats(node, value)
|
||||
# all values available used
|
||||
assert math.isclose(
|
||||
node_collector_stats.get_average_latency_time(node),
|
||||
sum(exec_times) / len(exec_times),
|
||||
)
|
||||
|
||||
# > moving average window
|
||||
for i in range(max_window * 2):
|
||||
value = random.uniform(0, 40)
|
||||
exec_times.append(value)
|
||||
node_collector_stats._update_stats(node, value)
|
||||
|
||||
# only the latest "max_window" values are used for the average, even though additional values were collected
|
||||
assert math.isclose(
|
||||
node_collector_stats.get_average_latency_time(node),
|
||||
sum(exec_times[-max_window:]) / len(exec_times[-max_window:]),
|
||||
)
|
||||
|
||||
|
||||
def test_collector_stats_reset(execution_data):
|
||||
executions, original_expected_node_sorted_order = execution_data
|
||||
node_latency_collector = NodeLatencyStatsCollector()
|
||||
|
||||
# update stats for all nodes
|
||||
for node, execution_times in executions.items():
|
||||
for exec_time in execution_times:
|
||||
node_latency_collector._update_stats(node, exec_time)
|
||||
|
||||
assert math.isclose(
|
||||
node_latency_collector.get_average_latency_time(node),
|
||||
sum(execution_times) / len(execution_times),
|
||||
)
|
||||
|
||||
# proper order
|
||||
assert (
|
||||
node_latency_collector.order_addresses_by_latency(list(executions.keys()))
|
||||
== original_expected_node_sorted_order
|
||||
)
|
||||
|
||||
# reset stats for fastest node, in which case it should now move to the end of the ordered list
|
||||
node_latency_collector.reset_stats(original_expected_node_sorted_order[0])
|
||||
assert (
|
||||
node_latency_collector.get_average_latency_time(
|
||||
original_expected_node_sorted_order[0]
|
||||
)
|
||||
== NodeLatencyStatsCollector.MAX_LATENCY
|
||||
)
|
||||
|
||||
updated_order = original_expected_node_sorted_order[1:] + [
|
||||
original_expected_node_sorted_order[0]
|
||||
]
|
||||
assert updated_order != original_expected_node_sorted_order
|
||||
assert (
|
||||
node_latency_collector.order_addresses_by_latency(list(executions.keys()))
|
||||
== updated_order
|
||||
)
|
||||
|
||||
# reset another node's stats
|
||||
node_latency_collector.reset_stats(updated_order[1])
|
||||
assert (
|
||||
node_latency_collector.get_average_latency_time(updated_order[1])
|
||||
== NodeLatencyStatsCollector.MAX_LATENCY
|
||||
)
|
||||
# the order the addresses are passed in dictates the order of nodes without stats
|
||||
expected_updated_updated_order = (
|
||||
[updated_order[0]] + updated_order[2:-1] + [updated_order[1], updated_order[3]]
|
||||
)
|
||||
assert (
|
||||
node_latency_collector.order_addresses_by_latency(updated_order)
|
||||
== expected_updated_updated_order
|
||||
)
|
||||
|
||||
# reset all stats
|
||||
for node in executions.keys():
|
||||
node_latency_collector.reset_stats(node)
|
||||
assert (
|
||||
node_latency_collector.get_average_latency_time(node)
|
||||
== NodeLatencyStatsCollector.MAX_LATENCY
|
||||
)
|
||||
all_reset_order = list(executions.keys())
|
||||
assert (
|
||||
node_latency_collector.order_addresses_by_latency(all_reset_order)
|
||||
== all_reset_order
|
||||
)
|
||||
|
||||
|
||||
def test_collector_simple_concurrency(execution_data):
|
||||
executions, expected_node_sorted_order = execution_data
|
||||
node_latency_collector = NodeLatencyStatsCollector()
|
||||
|
||||
def populate_executions(node_address: ChecksumAddress):
|
||||
execution_times = executions[node_address]
|
||||
for exec_time in execution_times:
|
||||
# add some delay for better concurrency
|
||||
time.sleep(0.1)
|
||||
node_latency_collector._update_stats(node_address, exec_time)
|
||||
|
||||
# use thread pool
|
||||
n_threads = len(executions)
|
||||
with ThreadPoolExecutor(n_threads) as executor:
|
||||
# download each url and save as a local file
|
||||
futures = []
|
||||
for node_address in executions.keys():
|
||||
f = executor.submit(populate_executions, node_address)
|
||||
futures.append(f)
|
||||
|
||||
wait(futures, timeout=3) # these shouldn't take long; only wait max 3s
|
||||
|
||||
assert (
|
||||
node_latency_collector.order_addresses_by_latency(list(executions.keys()))
|
||||
== expected_node_sorted_order
|
||||
)
|
||||
|
||||
|
||||
def test_collector_tracker_no_exception(execution_data):
|
||||
executions, expected_node_sorted_order = execution_data
|
||||
node_latency_collector = NodeLatencyStatsCollector()
|
||||
for node, execution_times in executions.items():
|
||||
for exec_time in execution_times:
|
||||
base_perf_counter = time.perf_counter()
|
||||
end_time = base_perf_counter + exec_time
|
||||
with patch("time.perf_counter", side_effect=[base_perf_counter, end_time]):
|
||||
with node_latency_collector.get_latency_tracker(node):
|
||||
# fake execution; do nothing
|
||||
time.sleep(0)
|
||||
|
||||
# floating point arithmetic makes an exact check tricky
|
||||
assert math.isclose(
|
||||
node_latency_collector.get_average_latency_time(node),
|
||||
sum(execution_times) / len(execution_times),
|
||||
)
|
||||
|
||||
node_addresses = list(executions.keys())
|
||||
for _ in range(10):
|
||||
# try various random permutations of order
|
||||
random.shuffle(node_addresses)
|
||||
assert (
|
||||
node_latency_collector.order_addresses_by_latency(node_addresses)
|
||||
== expected_node_sorted_order
|
||||
)
|
||||
|
||||
|
||||
def test_collector_tracker_exception(execution_data):
|
||||
executions, _ = execution_data
|
||||
node_latency_collector = NodeLatencyStatsCollector()
|
||||
|
||||
node_not_to_raise = random.sample(list(executions.keys()), 1)[0]
|
||||
for node, execution_times in executions.items():
|
||||
for exec_time in execution_times:
|
||||
base_perf_counter = time.perf_counter()
|
||||
end_time = base_perf_counter + exec_time
|
||||
with patch("time.perf_counter", side_effect=[base_perf_counter, end_time]):
|
||||
exception_propagated = False
|
||||
try:
|
||||
with node_latency_collector.get_latency_tracker(node):
|
||||
# raise exception during whatever node execution
|
||||
if node != node_not_to_raise:
|
||||
raise ConnectionRefusedError("random execution exception")
|
||||
except ConnectionRefusedError:
|
||||
exception_propagated = True
|
||||
|
||||
assert exception_propagated == (node != node_not_to_raise)
|
||||
|
||||
if node != node_not_to_raise:
|
||||
# no stats stored so average equals MAX_LATENCY
|
||||
assert (
|
||||
node_latency_collector.get_average_latency_time(node)
|
||||
== NodeLatencyStatsCollector.MAX_LATENCY
|
||||
)
|
||||
else:
|
||||
# floating point arithmetic makes an exact check tricky
|
||||
assert math.isclose(
|
||||
node_latency_collector.get_average_latency_time(node_not_to_raise),
|
||||
sum(execution_times) / len(execution_times),
|
||||
)
|
||||
|
||||
node_addresses = list(executions.keys())
|
||||
exp_sorted_addresses = [node_not_to_raise] + [
|
||||
a for a in node_addresses if a != node_not_to_raise
|
||||
]
|
||||
sorted_addresses = node_latency_collector.order_addresses_by_latency(node_addresses)
|
||||
assert sorted_addresses[0] == node_not_to_raise
|
||||
assert sorted_addresses == exp_sorted_addresses
|
Loading…
Reference in New Issue