From bb947185145277081a5948360c7b110897fdcac3 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Sat, 27 Feb 2021 18:20:56 -0800 Subject: [PATCH] Return a StateDiff object from record_fleet_state() --- nucypher/acumen/perception.py | 76 +++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/nucypher/acumen/perception.py b/nucypher/acumen/perception.py index 05acf5fdf..a7d54486f 100644 --- a/nucypher/acumen/perception.py +++ b/nucypher/acumen/perception.py @@ -46,6 +46,15 @@ class ArchivedFleetState(NamedTuple): population=self.population) +class StateDiff(NamedTuple): + this_node_updated: bool + nodes_updated: List[ChecksumAddress] + nodes_removed: List[ChecksumAddress] + + def empty(self): + return not self.this_node_updated and not self.nodes_updated and not self.nodes_removed + + class FleetState: """ Fleet state as perceived by a local Ursula. @@ -92,24 +101,28 @@ class FleetState: timestamp=self.timestamp, population=self.population) - def _remote_nodes_updated(self, - nodes_to_add: Iterable['Ursula'], - nodes_to_remove: Iterable[ChecksumAddress] - ) -> bool: + def _calculate_diff(self, + this_node_updated: bool, + nodes_to_add: Iterable['Ursula'], + nodes_to_remove: Iterable[ChecksumAddress] + ) -> StateDiff: + nodes_updated = [] for node in nodes_to_add: if node.checksum_address in nodes_to_remove: continue - if node.checksum_address not in self._nodes: - return True - if bytes(self._nodes[node.checksum_address]) != bytes(node): - return True + if ( node.checksum_address not in self._nodes + or bytes(self._nodes[node.checksum_address]) != bytes(node)): + nodes_updated.append(node.checksum_address) + nodes_removed = [] for checksum_address in nodes_to_remove: if checksum_address in self._nodes: - return True + nodes_removed.append(checksum_address) - return False + return StateDiff(this_node_updated=this_node_updated, + nodes_updated=nodes_updated, + nodes_removed=nodes_removed) def with_updated_nodes(self, nodes_to_add: Iterable['Ursula'], @@ -120,25 +133,25 @@ class FleetState: if self._this_node_ref is not None and not skip_this_node: this_node = self._this_node_ref() this_node_metadata = bytes(this_node) - this_node_changed = self._this_node_metadata != this_node_metadata + this_node_updated = self._this_node_metadata != this_node_metadata this_node_list = [this_node] else: this_node_metadata = self._this_node_metadata - this_node_changed = False + this_node_updated = False this_node_list = [] - remote_nodes_updated = self._remote_nodes_updated(nodes_to_add, nodes_to_remove) + diff = self._calculate_diff(this_node_updated, nodes_to_add, nodes_to_remove) - if this_node_changed or remote_nodes_updated: + if not diff.empty(): # TODO: if nodes were kept in a Merkle tree, # we'd have to only recalculate log(N) checksums. # Is it worth it? nodes = dict(self._nodes) - for node in nodes_to_add: - nodes[node.checksum_address] = node - for checksum_address in nodes_to_remove: - if checksum_address in nodes: - del nodes[checksum_address] + nodes_to_add_dict = {node.checksum_address: node for node in nodes_to_add} + for checksum_address in diff.nodes_updated: + nodes[checksum_address] = nodes_to_add_dict[checksum_address] + for checksum_address in diff.nodes_removed: + del nodes[checksum_address] all_nodes_sorted = sorted(itertools.chain(this_node_list, nodes.values()), key=lambda node: node.checksum_address) @@ -148,10 +161,12 @@ class FleetState: nodes = self._nodes checksum = self.checksum - return FleetState(checksum=checksum, - nodes=nodes, - this_node_ref=self._this_node_ref, - this_node_metadata=this_node_metadata) + new_state = FleetState(checksum=checksum, + nodes=nodes, + this_node_ref=self._this_node_ref, + this_node_metadata=this_node_metadata) + + return new_state, diff @property def population(self) -> int: @@ -329,10 +344,10 @@ class FleetSensor: def unpack_snapshot(data): return FleetState.unpack_snapshot(data) - def record_fleet_state(self, skip_this_node: bool = False) -> Optional[ArchivedFleetState]: - new_state = self._current_state.with_updated_nodes(nodes_to_add=self._nodes_to_add, - nodes_to_remove=self._nodes_to_remove, - skip_this_node=skip_this_node) + def record_fleet_state(self, skip_this_node: bool = False) -> StateDiff: + new_state, diff = self._current_state.with_updated_nodes(nodes_to_add=self._nodes_to_add, + nodes_to_remove=self._nodes_to_remove, + skip_this_node=skip_this_node) self._nodes_to_add = set() self._nodes_to_remove = set() @@ -343,12 +358,11 @@ class FleetSensor: # 1. (current) add a state to the archive every time it changes # 2. (possible) keep a dictionary of known states # and bump the timestamp of a previously encountered one - if new_state.checksum != self._archived_states[-1].checksum: + if not diff.empty(): archived_state = new_state.archived() self._archived_states.append(archived_state) - return archived_state - else: - return None + + return diff def shuffled(self): return self._current_state.shuffled()