Add utility to Learner for tracking node connection latency.

pull/3562/head
derekpierre 2024-10-31 13:34:17 -04:00
parent 2d85ccf07b
commit fc17bbfc53
No known key found for this signature in database
1 changed file with 79 additions and 7 deletions

View File

@ -1,8 +1,9 @@
import time
from collections import deque
from collections import defaultdict, deque
from contextlib import suppress
from queue import Queue
from typing import Optional, Set, Tuple
from threading import Lock
from typing import List, Optional, Set, Tuple
import maya
import requests
@ -53,6 +54,72 @@ TEACHER_NODES = {
}
class NodeLatencyStatsCollector:
    """
    Collects per-node latency statistics, guarded by a lock.

    For every staker address the collector keeps the accumulated elapsed time
    and the number of successful measurements; the average latency is derived
    from those two values. An address with no recorded measurements is
    reported with ``MAX_LATENCY`` so that it sorts after every measured one.
    """

    TOTAL_TIME = "total_time"
    COUNT = "count"
    MAX_LATENCY = 2**32 - 1  # just need a large number

    class NodeExecutionContextManager:
        """
        Times the body of a ``with`` statement for a single staker address.

        On a clean exit the elapsed time is added to the collector's stats;
        if the body raises, the stats for that address are reset and the
        exception is left to propagate.
        """

        def __init__(
            self,
            stats_collector: "NodeLatencyStatsCollector",
            staker_address: str,
        ):
            self._stats_collector = stats_collector
            self.staker_address = staker_address

        def __enter__(self):
            self.start_time = time.perf_counter()
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            if exc_type:
                # exception occurred - reset stats since connectivity was compromised
                self._stats_collector.reset_stats(self.staker_address)
            else:
                # no exception
                execution_time = time.perf_counter() - self.start_time
                self._stats_collector.update_stats(self.staker_address, execution_time)
            # implicit ``None`` return: an exception raised in the body is not suppressed

    def __init__(self):
        # staker_address -> { "total_time": <float>, "count": <integer> }
        self._node_stats = defaultdict(lambda: {self.TOTAL_TIME: 0.0, self.COUNT: 0})
        self._lock = Lock()

    def update_stats(self, staking_address: str, latest_time_taken: float) -> None:
        """Record one more successful measurement for ``staking_address``."""
        with self._lock:
            stats = self._node_stats[staking_address]
            stats[self.TOTAL_TIME] += latest_time_taken
            stats[self.COUNT] += 1

    def reset_stats(self, staking_address: str) -> None:
        """Forget all measurements recorded for ``staking_address``."""
        with self._lock:
            # Dropping the entry is equivalent to zeroing it (a missing entry
            # and a zero count both report MAX_LATENCY) and keeps the map small.
            self._node_stats.pop(staking_address, None)

    def get_latency_tracker(self, staker_address: str) -> NodeExecutionContextManager:
        """Return a context manager that times its body for ``staker_address``."""
        return self.NodeExecutionContextManager(
            stats_collector=self, staker_address=staker_address
        )

    def get_average_latency_time(self, staking_address: str) -> float:
        """
        Return the average recorded latency for ``staking_address``.

        ``MAX_LATENCY`` is returned when nothing has been recorded for the
        address (never measured, or reset since the last measurement).
        """
        with self._lock:
            # ``.get`` does not trigger the defaultdict factory, so a read for
            # an unknown address does not insert an empty entry.
            stats = self._node_stats.get(staking_address)
            if stats is None or stats[self.COUNT] == 0:
                # just need a large number > 0
                return self.MAX_LATENCY
            return stats[self.TOTAL_TIME] / stats[self.COUNT]

    def order_addresses_by_latency(self, staking_addresses: List[str]) -> List[str]:
        """
        Return ``staking_addresses`` as a new list sorted by ascending average latency.

        Addresses without measurements compare equal at ``MAX_LATENCY`` and,
        because the sort is stable, keep their relative input order at the end.
        """
        return sorted(staking_addresses, key=self.get_average_latency_time)
class NodeSprout:
"""
An abridged node class designed for optimization of instantiation of > 100 nodes simultaneously.
@ -261,6 +328,8 @@ class Learner:
self.log = Logger("learning-loop") # type: Logger
self.domain = domain
self.node_latency_collector = NodeLatencyStatsCollector()
self.learning_deferred = Deferred()
default_middleware = self.__DEFAULT_MIDDLEWARE_CLASS(
registry=self.registry, eth_endpoint=self.eth_endpoint
@ -827,11 +896,14 @@ class Learner:
return RELAX
try:
response = self.network_middleware.get_nodes_via_rest(
node=current_teacher,
announce_nodes=announce_nodes,
fleet_state_checksum=self.known_nodes.checksum,
)
with self.node_latency_collector.get_latency_tracker(
current_teacher.checksum_address
):
response = self.network_middleware.get_nodes_via_rest(
node=current_teacher,
announce_nodes=announce_nodes,
fleet_state_checksum=self.known_nodes.checksum,
)
# These except clauses apply to the current_teacher itself, not the learned-about nodes.
except NodeSeemsToBeDown as e:
unresponsive_nodes.add(current_teacher)