agent layer: isolates coordinator function interface exposure (abstracting page management to actors)

remotes/origin/v7.4.x
KPrasch 2024-01-28 18:40:13 +01:00 committed by Derek Pierre
parent a9e4a87e6c
commit e5782c65f8
3 changed files with 74 additions and 62 deletions

View File

@ -135,8 +135,8 @@ class NucypherTokenActor(BaseActor):
class Operator(BaseActor):
READY_TIMEOUT = None # (None or 0) == indefinite
READY_POLL_RATE = 120 # seconds
AGGREGATION_SUBMISSION_MAX_DELAY = 60
LOG = Logger("operator")
class OperatorError(BaseActor.ActorError):
"""Operator-specific errors."""

View File

@ -741,6 +741,19 @@ class CoordinatorAgent(EthereumContractAgent):
return participant_public_keys
@classmethod
def make_participants(cls, data: list, start: int = 0) -> Iterable[Participant]:
"""Converts a list of participant data into an iterable of Participant objects."""
for i, participant_data in enumerate(data, start=start):
participant = cls.Participant(
index=i,
provider=ChecksumAddress(participant_data[0]),
aggregated=participant_data[1],
transcript=bytes(participant_data[2]),
decryption_request_static_key=bytes(participant_data[3]),
)
yield participant
@contract_api(CONTRACT_CALL)
def get_timeout(self) -> int:
return self.contract.functions.timeout().call()
@ -767,22 +780,6 @@ class CoordinatorAgent(EthereumContractAgent):
ritual.public_key = self.Ritual.G1Point(result[10][0], result[10][1])
return ritual
def get_ritual(
self,
ritual_id: int,
transcripts: bool = False,
) -> Ritual:
"""assembles Coordinator data using multiple RPC requests."""
ritual = self.__rituals(ritual_id) # 1 rpc call
participants = self.get_participants(
ritual_id=ritual_id,
num_participants=ritual.dkg_size,
include_transcripts=transcripts,
) # 1 rpc call
ritual.participants = participants
return ritual
@contract_api(CONTRACT_CALL)
def get_ritual_status(self, ritual_id: int) -> int:
result = self.contract.functions.getRitualState(ritual_id).call()
@ -798,6 +795,52 @@ class CoordinatorAgent(EthereumContractAgent):
result = self.contract.functions.isParticipant(ritual_id, provider).call()
return result
def _get_participants(
self,
ritual: Ritual,
ritual_id: int,
transcripts: bool,
page_size: int = 10, # Default pagination size
) -> Iterable[Ritual.Participant]:
"""Fetches all participants for a given ritual."""
start, end = 0, ritual.dkg_size
while start < end:
current_page_size = min(page_size, end - start)
batch = self.get_participants(
ritual_id=ritual_id,
start=start,
end=start + current_page_size,
transcripts=transcripts,
)
for participant in batch:
yield participant
start += current_page_size
def get_ritual(
self,
ritual_id: int,
transcripts: bool = False,
participants: bool = True,
) -> Ritual:
"""
Exposes three views of Coordinator.Rituals:
1. The ritual metadata only
2. ritual + participant metadata
3. ritual + participants + transcripts
"""
ritual = self.__rituals(ritual_id)
if participants:
ritual.participants = list(
self._get_participants(
ritual=ritual,
ritual_id=ritual_id,
transcripts=transcripts,
)
)
elif transcripts:
raise ValueError("Cannot fetch transcripts without participants")
return ritual
@contract_api(CONTRACT_CALL)
def get_participant(
self, ritual_id: int, provider: ChecksumAddress, transcript: bool
@ -805,55 +848,24 @@ class CoordinatorAgent(EthereumContractAgent):
data, index = self.contract.functions.getParticipant(
ritual_id, provider, transcript
).call()
participant = self.Ritual.Participant(
index=index,
provider=ChecksumAddress(data[0]),
aggregated=data[1],
transcript=bytes(data[2]),
decryption_request_static_key=bytes(data[3]),
)
participant = next(iter(self.Ritual.make_participants([data], start=index)))
return participant
@contract_api(CONTRACT_CALL)
def get_participants(
self,
ritual_id: int,
num_participants: int,
pagination_size: Optional[int] = None,
include_transcripts: Optional[bool] = False,
start: Optional[int],
end: Optional[int],
transcripts: Optional[bool] = False,
) -> List[Ritual.Participant]:
ritual_participants = list()
if num_participants <= 0:
raise ValueError("Number of participants must be >= 0")
if pagination_size is None:
pagination_size = 10 # TODO what's a good value here?
self.log.debug(
f"Defaulting to pagination size of {pagination_size} for participants"
)
elif pagination_size <= 0:
raise ValueError("Pagination size must be > 0")
# TODO: another option is if include_transcripts is false, then don't bother paginating...?
if pagination_size > 0:
start_index: int = 0
while start_index < num_participants:
result = self.contract.functions.getParticipants(
ritual_id, start_index, pagination_size, include_transcripts
).call()
for i, participant_data in enumerate(result):
participant = self.Ritual.Participant(
index=start_index + i,
provider=ChecksumAddress(participant_data[0]),
aggregated=participant_data[1],
transcript=bytes(participant_data[2]),
decryption_request_static_key=bytes(participant_data[3]),
)
ritual_participants.append(participant)
start_index += pagination_size
return ritual_participants
if end < start:
raise ValueError("End must be greater than or equal to start")
data = self.contract.functions.getParticipants(
ritual_id, start, end - start, transcripts
).call()
participants = self.Ritual.make_participants(data, start=start)
return list(participants)
@contract_api(CONTRACT_CALL)
def get_provider_public_key(

View File

@ -222,16 +222,16 @@ class MockCoordinatorAgent(MockContractAgent):
return len(self.rituals)
def get_ritual(
self, ritual_id: int, transcripts: bool = False
self, ritual_id: int, transcripts: bool = False, participants: bool = True
) -> CoordinatorAgent.Ritual:
ritual = self.rituals[ritual_id]
# return a copy of the ritual object; the original value is used for state
copied_ritual = deepcopy(ritual)
if not participants:
return copied_ritual
if not transcripts:
for participant in copied_ritual.participants:
participant.transcript = bytes()
return copied_ritual
def is_participant(self, ritual_id: int, provider: ChecksumAddress) -> bool: