Return a StateDiff object from record_fleet_state()

pull/2574/head
Bogdan Opanchuk 2021-02-27 18:20:56 -08:00
parent b20a6da134
commit bb94718514
1 changed files with 45 additions and 31 deletions

View File

@ -46,6 +46,15 @@ class ArchivedFleetState(NamedTuple):
population=self.population)
class StateDiff(NamedTuple):
this_node_updated: bool
nodes_updated: List[ChecksumAddress]
nodes_removed: List[ChecksumAddress]
def empty(self):
return not self.this_node_updated and not self.nodes_updated and not self.nodes_removed
class FleetState:
"""
Fleet state as perceived by a local Ursula.
@ -92,24 +101,28 @@ class FleetState:
timestamp=self.timestamp,
population=self.population)
def _remote_nodes_updated(self,
nodes_to_add: Iterable['Ursula'],
nodes_to_remove: Iterable[ChecksumAddress]
) -> bool:
def _calculate_diff(self,
this_node_updated: bool,
nodes_to_add: Iterable['Ursula'],
nodes_to_remove: Iterable[ChecksumAddress]
) -> StateDiff:
nodes_updated = []
for node in nodes_to_add:
if node.checksum_address in nodes_to_remove:
continue
if node.checksum_address not in self._nodes:
return True
if bytes(self._nodes[node.checksum_address]) != bytes(node):
return True
if ( node.checksum_address not in self._nodes
or bytes(self._nodes[node.checksum_address]) != bytes(node)):
nodes_updated.append(node.checksum_address)
nodes_removed = []
for checksum_address in nodes_to_remove:
if checksum_address in self._nodes:
return True
nodes_removed.append(checksum_address)
return False
return StateDiff(this_node_updated=this_node_updated,
nodes_updated=nodes_updated,
nodes_removed=nodes_removed)
def with_updated_nodes(self,
nodes_to_add: Iterable['Ursula'],
@ -120,25 +133,25 @@ class FleetState:
if self._this_node_ref is not None and not skip_this_node:
this_node = self._this_node_ref()
this_node_metadata = bytes(this_node)
this_node_changed = self._this_node_metadata != this_node_metadata
this_node_updated = self._this_node_metadata != this_node_metadata
this_node_list = [this_node]
else:
this_node_metadata = self._this_node_metadata
this_node_changed = False
this_node_updated = False
this_node_list = []
remote_nodes_updated = self._remote_nodes_updated(nodes_to_add, nodes_to_remove)
diff = self._calculate_diff(this_node_updated, nodes_to_add, nodes_to_remove)
if this_node_changed or remote_nodes_updated:
if not diff.empty():
# TODO: if nodes were kept in a Merkle tree,
# we'd have to only recalculate log(N) checksums.
# Is it worth it?
nodes = dict(self._nodes)
for node in nodes_to_add:
nodes[node.checksum_address] = node
for checksum_address in nodes_to_remove:
if checksum_address in nodes:
del nodes[checksum_address]
nodes_to_add_dict = {node.checksum_address: node for node in nodes_to_add}
for checksum_address in diff.nodes_updated:
nodes[checksum_address] = nodes_to_add_dict[checksum_address]
for checksum_address in diff.nodes_removed:
del nodes[checksum_address]
all_nodes_sorted = sorted(itertools.chain(this_node_list, nodes.values()),
key=lambda node: node.checksum_address)
@ -148,10 +161,12 @@ class FleetState:
nodes = self._nodes
checksum = self.checksum
return FleetState(checksum=checksum,
nodes=nodes,
this_node_ref=self._this_node_ref,
this_node_metadata=this_node_metadata)
new_state = FleetState(checksum=checksum,
nodes=nodes,
this_node_ref=self._this_node_ref,
this_node_metadata=this_node_metadata)
return new_state, diff
@property
def population(self) -> int:
@ -329,10 +344,10 @@ class FleetSensor:
def unpack_snapshot(data):
return FleetState.unpack_snapshot(data)
def record_fleet_state(self, skip_this_node: bool = False) -> Optional[ArchivedFleetState]:
new_state = self._current_state.with_updated_nodes(nodes_to_add=self._nodes_to_add,
nodes_to_remove=self._nodes_to_remove,
skip_this_node=skip_this_node)
def record_fleet_state(self, skip_this_node: bool = False) -> StateDiff:
new_state, diff = self._current_state.with_updated_nodes(nodes_to_add=self._nodes_to_add,
nodes_to_remove=self._nodes_to_remove,
skip_this_node=skip_this_node)
self._nodes_to_add = set()
self._nodes_to_remove = set()
@ -343,12 +358,11 @@ class FleetSensor:
# 1. (current) add a state to the archive every time it changes
# 2. (possible) keep a dictionary of known states
# and bump the timestamp of a previously encountered one
if new_state.checksum != self._archived_states[-1].checksum:
if not diff.empty():
archived_state = new_state.archived()
self._archived_states.append(archived_state)
return archived_state
else:
return None
return diff
def shuffled(self):
return self._current_state.shuffled()