Merge pull request #3562 from derekpierre/latency-tracking

Learner node latency tracking
Derek Pierre 2024-11-06 11:03:55 -05:00 committed by GitHub
commit a78ac58819
7 changed files with 482 additions and 52 deletions

View File

@@ -0,0 +1 @@
Enhance threshold decryption request efficiency by prioritizing nodes in the cohort with lower communication latency.
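To make the idea concrete, here is a minimal, self-contained sketch of latency-prioritized ordering. The helper names (`order_by_latency`, `UNKNOWN_LATENCY`) are illustrative only; the PR's actual implementation is the `NodeLatencyStatsCollector` class shown later in this diff.

from typing import Dict, List

UNKNOWN_LATENCY = float(2**16)  # nodes with no samples sort last

def order_by_latency(addresses: List[str], samples: Dict[str, List[float]]) -> List[str]:
    """Return addresses sorted by average observed latency, unmeasured nodes last."""
    def average(address: str) -> float:
        readings = samples.get(address, [])
        return sum(readings) / len(readings) if readings else UNKNOWN_LATENCY
    return sorted(addresses, key=average)

# The node with the lowest average latency is contacted first; the
# unmeasured node ("0xC") drops to the end of the ordering.
assert order_by_latency(
    ["0xA", "0xB", "0xC"],
    {"0xA": [0.9, 1.1], "0xB": [0.2, 0.3]},
) == ["0xB", "0xA", "0xC"]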

View File

@@ -387,6 +387,8 @@ class Alice(Character, actors.PolicyAuthor):
class Bob(Character):
_TRACK_NODE_LATENCY_STATS = True
banner = BOB_BANNER
_default_dkg_variant = FerveoVariant.Simple
_default_crypto_powerups = [SigningPower, DecryptingPower]
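For orientation, the new flag is a class-level opt-in that the Learner base class (changed later in this diff) reads at construction time, so flipping it on Bob is enough to get a collector. A stripped-down sketch of the pattern, with the real classes reduced to the relevant attribute:

# Simplified sketch; the real classes carry far more state.
class NodeLatencyStatsCollector:  # stand-in for nucypher.utilities.latency.NodeLatencyStatsCollector
    pass

class Learner:
    _TRACK_NODE_LATENCY_STATS = False  # tracking is off by default

    def __init__(self):
        # Subclasses opt in simply by overriding the class attribute.
        self.node_latency_collector = (
            NodeLatencyStatsCollector() if self._TRACK_NODE_LATENCY_STATS else None
        )

class Bob(Learner):
    _TRACK_NODE_LATENCY_STATS = True  # Bob tracks teacher latency; other characters keep the default

assert Bob().node_latency_collector is not None
assert Learner().node_latency_collector is None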

View File

@@ -1,6 +1,5 @@
import math
from http import HTTPStatus
from random import shuffle
from typing import Dict, List, Tuple
from eth_typing import ChecksumAddress
@@ -80,15 +79,20 @@ class ThresholdDecryptionClient(ThresholdAccessControlClient):
self.log.warn(message)
raise self.ThresholdDecryptionRequestFailed(message)
# TODO: Find a better request order, perhaps based on latency data obtained from discovery loop - #3395
requests = list(encrypted_requests)
shuffle(requests)
ursulas_to_contact = (
self._learner.node_latency_collector.order_addresses_by_latency(
list(encrypted_requests)
)
if self._learner.node_latency_collector
else list(encrypted_requests)
)
# Discussion about WorkerPool parameters:
# "https://github.com/nucypher/nucypher/pull/3393#discussion_r1456307991"
worker_pool = WorkerPool(
worker=worker,
value_factory=self.ThresholdDecryptionRequestFactory(
ursulas_to_contact=requests,
ursulas_to_contact=ursulas_to_contact,
batch_size=math.ceil(threshold * 1.25),
threshold=threshold,
),
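A rough illustration of the two decisions made above, using made-up values (a threshold of 16 and a disabled collector): ordering falls back to the caller's order when latency tracking is off, and each batch is padded to 25% above the threshold, presumably to absorb slow or unresponsive nodes.

import math

threshold = 16  # illustrative cohort threshold
batch_size = math.ceil(threshold * 1.25)
assert batch_size == 20  # contact a few more nodes than strictly required per batch

collector = None  # latency tracking disabled in this sketch
encrypted_requests = {"0xB": b"...", "0xA": b"..."}  # address -> request
ursulas_to_contact = (
    collector.order_addresses_by_latency(list(encrypted_requests))
    if collector
    else list(encrypted_requests)
)
assert ursulas_to_contact == ["0xB", "0xA"]  # original order preserved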

View File

@@ -1,3 +1,4 @@
import contextlib
import time
from collections import deque
from contextlib import suppress
@@ -40,6 +41,7 @@ from nucypher.crypto.signing import InvalidSignature, SignatureStamp
from nucypher.network.exceptions import NodeSeemsToBeDown
from nucypher.network.middleware import RestMiddleware
from nucypher.network.protocols import InterfaceInfo, SuspiciousActivity
from nucypher.utilities.latency import NodeLatencyStatsCollector
from nucypher.utilities.logging import Logger
TEACHER_NODES = {
@@ -218,6 +220,8 @@ class Learner:
_ROUNDS_WITHOUT_NODES_AFTER_WHICH_TO_SLOW_DOWN = 10
__DEFAULT_MIDDLEWARE_CLASS = RestMiddleware
_TRACK_NODE_LATENCY_STATS = False
_crashed = (
False # moved from Character - why was this in Character and not Learner before
)
@@ -261,6 +265,10 @@
self.log = Logger("learning-loop") # type: Logger
self.domain = domain
self.node_latency_collector = (
NodeLatencyStatsCollector() if self._TRACK_NODE_LATENCY_STATS else None
)
self.learning_deferred = Deferred()
default_middleware = self.__DEFAULT_MIDDLEWARE_CLASS(
registry=self.registry, eth_endpoint=self.eth_endpoint
@@ -827,11 +835,19 @@ class Learner:
return RELAX
try:
response = self.network_middleware.get_nodes_via_rest(
node=current_teacher,
announce_nodes=announce_nodes,
fleet_state_checksum=self.known_nodes.checksum,
optional_latency_context_manager = (
self.node_latency_collector.get_latency_tracker(
current_teacher.checksum_address
)
if self.node_latency_collector
else contextlib.nullcontext()
)
with optional_latency_context_manager:
response = self.network_middleware.get_nodes_via_rest(
node=current_teacher,
announce_nodes=announce_nodes,
fleet_state_checksum=self.known_nodes.checksum,
)
# These except clauses apply to the current_teacher itself, not the learned-about nodes.
except NodeSeemsToBeDown as e:
unresponsive_nodes.add(current_teacher)
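The `contextlib.nullcontext()` fallback above keeps the call site uniform whether or not a collector exists. A small sketch of the same pattern in isolation; `fetch_nodes` is a hypothetical stand-in for the middleware's `get_nodes_via_rest` call:

import contextlib

def fetch_nodes():
    # hypothetical stand-in for network_middleware.get_nodes_via_rest(...)
    return "response"

collector = None  # no latency tracking configured in this sketch
optional_tracker = (
    collector.get_latency_tracker("0xTeacher")
    if collector
    else contextlib.nullcontext()
)
with optional_tracker:  # no-op when tracking is off; times the call when it is on
    response = fetch_nodes()
assert response == "response"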

View File

@@ -0,0 +1,78 @@
import collections
import time
from collections import defaultdict
from threading import Lock
from typing import List

from eth_typing import ChecksumAddress


class NodeLatencyStatsCollector:
    """
    Thread-safe utility that tracks latency statistics related to P2P connections with other nodes.
    """

    MAX_MOVING_AVERAGE_WINDOW = 5
    MAX_LATENCY = float(2**16)  # just need a large number for sorting

    class NodeLatencyContextManager:
        def __init__(
            self,
            stats_collector: "NodeLatencyStatsCollector",
            staker_address: ChecksumAddress,
        ):
            self._stats_collector = stats_collector
            self.staker_address = staker_address

        def __enter__(self):
            self.start_time = time.perf_counter()
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            if exc_type:
                # exception occurred - reset stats since connectivity was compromised
                self._stats_collector.reset_stats(self.staker_address)
            else:
                # no exception
                end_time = time.perf_counter()
                execution_time = end_time - self.start_time
                self._stats_collector._update_stats(self.staker_address, execution_time)

    def __init__(self, max_moving_average_window: int = MAX_MOVING_AVERAGE_WINDOW):
        self._node_stats = defaultdict(
            lambda: collections.deque([], maxlen=max_moving_average_window)
        )
        self._lock = Lock()

    def _update_stats(self, staking_address: ChecksumAddress, latest_time_taken: float):
        with self._lock:
            self._node_stats[staking_address].append(latest_time_taken)

    def reset_stats(self, staking_address: ChecksumAddress):
        with self._lock:
            self._node_stats[staking_address].clear()

    def get_latency_tracker(
        self, staker_address: ChecksumAddress
    ) -> NodeLatencyContextManager:
        return self.NodeLatencyContextManager(
            stats_collector=self, staker_address=staker_address
        )

    def get_average_latency_time(self, staking_address: ChecksumAddress) -> float:
        with self._lock:
            readings = list(self._node_stats[staking_address])
            num_readings = len(readings)
            # just need a large number > 0
            return (
                self.MAX_LATENCY if num_readings == 0 else sum(readings) / num_readings
            )

    def order_addresses_by_latency(
        self, staking_addresses: List[ChecksumAddress]
    ) -> List[ChecksumAddress]:
        result = sorted(
            staking_addresses,
            key=lambda staking_address: self.get_average_latency_time(staking_address),
        )
        return result
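A brief usage sketch of the collector defined above, with made-up addresses and timings (plain strings stand in for checksum addresses): wrap each network round trip in `get_latency_tracker()`, then rank peers with `order_addresses_by_latency()`. An exception inside the tracked block clears that peer's samples, so it falls back to `MAX_LATENCY` and sorts last.

import time

from nucypher.utilities.latency import NodeLatencyStatsCollector

collector = NodeLatencyStatsCollector()

for address, delay in [("0xFast", 0.0), ("0xSlow", 0.05)]:
    with collector.get_latency_tracker(address):
        time.sleep(delay)  # stand-in for a real network round trip

assert collector.order_addresses_by_latency(["0xSlow", "0xFast"]) == ["0xFast", "0xSlow"]

# A failure while tracked resets the peer's stats and pushes it to the back.
try:
    with collector.get_latency_tracker("0xFast"):
        raise ConnectionError("simulated outage")
except ConnectionError:
    pass
assert collector.order_addresses_by_latency(["0xSlow", "0xFast"]) == ["0xSlow", "0xFast"]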

View File

@@ -1,6 +1,6 @@
import os
import random
from unittest.mock import patch
from unittest.mock import ANY, patch
import pytest
import pytest_twisted
@@ -12,6 +12,7 @@ from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.blockchain.eth.models import Coordinator
from nucypher.blockchain.eth.signers.software import InMemorySigner
from nucypher.characters.lawful import Enrico, Ursula
from nucypher.network.decryption import ThresholdDecryptionClient
from nucypher.policy.conditions.evm import ContractCondition, RPCCondition
from nucypher.policy.conditions.lingo import (
ConditionLingo,
@@ -275,19 +276,88 @@ def test_unauthorized_decryption(
yield
def check_decrypt_without_any_cached_values(
@pytest_twisted.inlineCallbacks
def test_authorized_decryption(
mocker,
bob,
global_allow_list,
accounts,
coordinator_agent,
threshold_message_kit,
signer,
initiator,
ritual_id,
cohort,
plaintext,
):
print("==================== DKG DECRYPTION (AUTHORIZED) ====================")
# authorize Enrico to encrypt for ritual
global_allow_list.authorize(
ritual_id,
[signer.accounts[0]],
sender=accounts[initiator.transacting_power.account],
)
# fake some latency stats
latency_stats = {}
for ursula in cohort:
# reset all stats
bob.node_latency_collector.reset_stats(ursula.checksum_address)
# add a single data point for each ursula: some time between 0.1 and 4
mock_latency = random.uniform(0.1, 4)
bob.node_latency_collector._update_stats(ursula.checksum_address, mock_latency)
latency_stats[ursula.checksum_address] = mock_latency
expected_ursula_request_ordering = sorted(
list(latency_stats.keys()),
key=lambda ursula_checksum: latency_stats[ursula_checksum],
)
value_factory_spy = mocker.spy(
ThresholdDecryptionClient.ThresholdDecryptionRequestFactory, "__init__"
)
# ritual_id, ciphertext, conditions are obtained from the side channel
bob.start_learning_loop(now=True)
cleartext = yield bob.threshold_decrypt(
threshold_message_kit=threshold_message_kit,
)
assert bytes(cleartext) == plaintext.encode()
# check that the expected ursula ordering was used for the worker pool request factory
value_factory_spy.assert_called_once_with(
ANY,
ursulas_to_contact=expected_ursula_request_ordering,
batch_size=ANY,
threshold=ANY,
)
# check prometheus metric for decryption requests
# since all ursulas run on the same machine, the value is aggregated across ursulas rather than per-ursula
num_successes = REGISTRY.get_sample_value(
"threshold_decryption_num_successes_total"
)
ritual = coordinator_agent.get_ritual(ritual_id)
# at least a threshold of ursulas were successful (concurrency)
assert int(num_successes) >= ritual.threshold
print("===================== DECRYPTION SUCCESSFUL =====================")
yield
@pytest_twisted.inlineCallbacks
def test_decrypt_without_any_cached_values(
threshold_message_kit, ritual_id, cohort, bob, coordinator_agent, plaintext
):
print("==================== DKG DECRYPTION NO CACHE ====================")
original_validators = cohort[0].dkg_storage.get_validators(ritual_id)
for ursula in cohort:
ursula.dkg_storage.clear(ritual_id)
assert ursula.dkg_storage.get_validators(ritual_id) is None
assert ursula.dkg_storage.get_active_ritual(ritual_id) is None
# perform threshold decryption
bob.start_learning_loop(now=True)
cleartext = bob.threshold_decrypt(
cleartext = yield bob.threshold_decrypt(
threshold_message_kit=threshold_message_kit,
)
assert bytes(cleartext) == plaintext.encode()
@@ -308,45 +378,7 @@ def check_decrypt_without_any_cached_values(
assert v.public_key == original_validators[v_index].public_key
assert num_used_ursulas >= ritual.threshold
print("===================== DECRYPTION SUCCESSFUL =====================")
@pytest_twisted.inlineCallbacks
def test_authorized_decryption(
bob,
global_allow_list,
accounts,
coordinator_agent,
threshold_message_kit,
signer,
initiator,
ritual_id,
plaintext,
):
print("==================== DKG DECRYPTION (AUTHORIZED) ====================")
# authorize Enrico to encrypt for ritual
global_allow_list.authorize(
ritual_id,
[signer.accounts[0]],
sender=accounts[initiator.transacting_power.account],
)
# ritual_id, ciphertext, conditions are obtained from the side channel
bob.start_learning_loop(now=True)
cleartext = yield bob.threshold_decrypt(
threshold_message_kit=threshold_message_kit,
)
assert bytes(cleartext) == plaintext.encode()
# check prometheus metric for decryption requests
# since all running on the same machine - the value is not per-ursula but rather all
num_successes = REGISTRY.get_sample_value(
"threshold_decryption_num_successes_total"
)
ritual = coordinator_agent.get_ritual(ritual_id)
# at least a threshold of ursulas were successful (concurrency)
assert int(num_successes) >= ritual.threshold
print("===================== DECRYPTION NO CACHE SUCCESSFUL =====================")
yield

View File

@@ -0,0 +1,297 @@
import math
import random
import time
from concurrent.futures import ThreadPoolExecutor, wait
from unittest.mock import patch
import pytest
from eth_typing import ChecksumAddress
from nucypher.utilities.latency import NodeLatencyStatsCollector
@pytest.fixture(scope="module")
def execution_data(get_random_checksum_address):
executions = {}
node_1 = get_random_checksum_address()
node_1_exec_times = [11.23, 24.8, 31.5, 40.21]
executions[node_1] = node_1_exec_times
node_2 = get_random_checksum_address()
node_2_exec_times = [5.03, 6.78, 7.42, 8.043]
executions[node_2] = node_2_exec_times
node_3 = get_random_checksum_address()
node_3_exec_times = [0.44, 4.512, 3.3]
executions[node_3] = node_3_exec_times
node_4 = get_random_checksum_address()
node_4_exec_times = [0.1, 0.2, 0.3, 0.4, 0.5]
executions[node_4] = node_4_exec_times
sorted_order = sorted(
list(executions.keys()),
key=lambda x: sum(executions.get(x)) / len(executions.get(x)),
)
assert sorted_order == [
node_4,
node_3,
node_2,
node_1,
] # test of the test - "that's sooo meta"
return executions, sorted_order
def test_collector_initialization_no_data_collected(get_random_checksum_address):
node_latency_collector = NodeLatencyStatsCollector()
staker_addresses = [get_random_checksum_address() for _ in range(4)]
# no data collected so average equals maximum latency
for staker_address in staker_addresses:
assert (
node_latency_collector.get_average_latency_time(staker_address)
== NodeLatencyStatsCollector.MAX_LATENCY
)
# no data collected so no change in order
assert (
node_latency_collector.order_addresses_by_latency(staker_addresses)
== staker_addresses
)
def test_collector_stats_obtained(execution_data):
executions, expected_node_sorted_order = execution_data
node_latency_collector = NodeLatencyStatsCollector()
# update stats for all nodes
for node, execution_times in executions.items():
for i, exec_time in enumerate(execution_times):
node_latency_collector._update_stats(node, exec_time)
# check ongoing average
subset_of_times = execution_times[: (i + 1)]
# floating point arithmetic makes an exact check tricky
assert math.isclose(
node_latency_collector.get_average_latency_time(node),
sum(subset_of_times) / len(subset_of_times),
)
# check final average
# floating point arithmetic makes an exact check tricky
assert math.isclose(
node_latency_collector.get_average_latency_time(node),
sum(execution_times) / len(execution_times),
)
node_addresses = list(executions.keys())
for _ in range(10):
# try various random permutations of order
random.shuffle(node_addresses)
assert (
node_latency_collector.order_addresses_by_latency(node_addresses)
== expected_node_sorted_order
)
@pytest.mark.parametrize(
"max_window", [NodeLatencyStatsCollector.MAX_MOVING_AVERAGE_WINDOW, 3, 7, 15]
)
def test_collector_moving_average_window(max_window, get_random_checksum_address):
node_collector_stats = NodeLatencyStatsCollector(
max_moving_average_window=max_window
)
node = get_random_checksum_address()
exec_times = []
# <= moving average window
for i in range(max_window):
value = random.uniform(0, 40)
exec_times.append(value)
node_collector_stats._update_stats(node, value)
# all available values are used for the average
assert math.isclose(
node_collector_stats.get_average_latency_time(node),
sum(exec_times) / len(exec_times),
)
# > moving average window
for i in range(max_window * 2):
value = random.uniform(0, 40)
exec_times.append(value)
node_collector_stats._update_stats(node, value)
# only the latest "max_window" values are used for the average, even though additional values were collected
assert math.isclose(
node_collector_stats.get_average_latency_time(node),
sum(exec_times[-max_window:]) / len(exec_times[-max_window:]),
)
def test_collector_stats_reset(execution_data):
executions, original_expected_node_sorted_order = execution_data
node_latency_collector = NodeLatencyStatsCollector()
# update stats for all nodes
for node, execution_times in executions.items():
for exec_time in execution_times:
node_latency_collector._update_stats(node, exec_time)
assert math.isclose(
node_latency_collector.get_average_latency_time(node),
sum(execution_times) / len(execution_times),
)
# proper order
assert (
node_latency_collector.order_addresses_by_latency(list(executions.keys()))
== original_expected_node_sorted_order
)
# reset stats for fastest node, in which case it should now move to the end of the ordered list
node_latency_collector.reset_stats(original_expected_node_sorted_order[0])
assert (
node_latency_collector.get_average_latency_time(
original_expected_node_sorted_order[0]
)
== NodeLatencyStatsCollector.MAX_LATENCY
)
updated_order = original_expected_node_sorted_order[1:] + [
original_expected_node_sorted_order[0]
]
assert updated_order != original_expected_node_sorted_order
assert (
node_latency_collector.order_addresses_by_latency(list(executions.keys()))
== updated_order
)
# reset another node's stats
node_latency_collector.reset_stats(updated_order[1])
assert (
node_latency_collector.get_average_latency_time(updated_order[1])
== NodeLatencyStatsCollector.MAX_LATENCY
)
# the order in which addresses are passed in dictates the relative order of nodes without stats
expected_updated_updated_order = (
[updated_order[0]] + updated_order[2:-1] + [updated_order[1], updated_order[3]]
)
assert (
node_latency_collector.order_addresses_by_latency(updated_order)
== expected_updated_updated_order
)
# reset all stats
for node in executions.keys():
node_latency_collector.reset_stats(node)
assert (
node_latency_collector.get_average_latency_time(node)
== NodeLatencyStatsCollector.MAX_LATENCY
)
all_reset_order = list(executions.keys())
assert (
node_latency_collector.order_addresses_by_latency(all_reset_order)
== all_reset_order
)
def test_collector_simple_concurrency(execution_data):
executions, expected_node_sorted_order = execution_data
node_latency_collector = NodeLatencyStatsCollector()
def populate_executions(node_address: ChecksumAddress):
execution_times = executions[node_address]
for exec_time in execution_times:
# add a small delay to encourage thread interleaving
time.sleep(0.1)
node_latency_collector._update_stats(node_address, exec_time)
# use thread pool
n_threads = len(executions)
with ThreadPoolExecutor(n_threads) as executor:
# populate execution stats for each node concurrently
futures = []
for node_address in executions.keys():
f = executor.submit(populate_executions, node_address)
futures.append(f)
wait(futures, timeout=3) # these shouldn't take long; only wait max 3s
assert (
node_latency_collector.order_addresses_by_latency(list(executions.keys()))
== expected_node_sorted_order
)
def test_collector_tracker_no_exception(execution_data):
executions, expected_node_sorted_order = execution_data
node_latency_collector = NodeLatencyStatsCollector()
for node, execution_times in executions.items():
for exec_time in execution_times:
base_perf_counter = time.perf_counter()
end_time = base_perf_counter + exec_time
with patch("time.perf_counter", side_effect=[base_perf_counter, end_time]):
with node_latency_collector.get_latency_tracker(node):
# fake execution; do nothing
time.sleep(0)
# floating point arithmetic makes an exact check tricky
assert math.isclose(
node_latency_collector.get_average_latency_time(node),
sum(execution_times) / len(execution_times),
)
node_addresses = list(executions.keys())
for _ in range(10):
# try various random permutations of order
random.shuffle(node_addresses)
assert (
node_latency_collector.order_addresses_by_latency(node_addresses)
== expected_node_sorted_order
)
def test_collector_tracker_exception(execution_data):
executions, _ = execution_data
node_latency_collector = NodeLatencyStatsCollector()
node_not_to_raise = random.sample(list(executions.keys()), 1)[0]
for node, execution_times in executions.items():
for exec_time in execution_times:
base_perf_counter = time.perf_counter()
end_time = base_perf_counter + exec_time
with patch("time.perf_counter", side_effect=[base_perf_counter, end_time]):
exception_propagated = False
try:
with node_latency_collector.get_latency_tracker(node):
# raise an exception during the simulated node execution (all nodes except one)
if node != node_not_to_raise:
raise ConnectionRefusedError("random execution exception")
except ConnectionRefusedError:
exception_propagated = True
assert exception_propagated == (node != node_not_to_raise)
if node != node_not_to_raise:
# no stats stored so average equals MAX_LATENCY
assert (
node_latency_collector.get_average_latency_time(node)
== NodeLatencyStatsCollector.MAX_LATENCY
)
else:
# floating point arithmetic makes an exact check tricky
assert math.isclose(
node_latency_collector.get_average_latency_time(node_not_to_raise),
sum(execution_times) / len(execution_times),
)
node_addresses = list(executions.keys())
exp_sorted_addresses = [node_not_to_raise] + [
a for a in node_addresses if a != node_not_to_raise
]
sorted_addresses = node_latency_collector.order_addresses_by_latency(node_addresses)
assert sorted_addresses[0] == node_not_to_raise
assert sorted_addresses == exp_sorted_addresses