mirror of https://github.com/nucypher/nucypher.git
agent layer: isolates coordinator function interface exposure (abstracting page management to actors)
parent
a9e4a87e6c
commit
e5782c65f8
|
@ -135,8 +135,8 @@ class NucypherTokenActor(BaseActor):
|
|||
class Operator(BaseActor):
|
||||
READY_TIMEOUT = None # (None or 0) == indefinite
|
||||
READY_POLL_RATE = 120 # seconds
|
||||
|
||||
AGGREGATION_SUBMISSION_MAX_DELAY = 60
|
||||
LOG = Logger("operator")
|
||||
|
||||
class OperatorError(BaseActor.ActorError):
|
||||
"""Operator-specific errors."""
|
||||
|
|
|
@ -741,6 +741,19 @@ class CoordinatorAgent(EthereumContractAgent):
|
|||
|
||||
return participant_public_keys
|
||||
|
||||
@classmethod
|
||||
def make_participants(cls, data: list, start: int = 0) -> Iterable[Participant]:
|
||||
"""Converts a list of participant data into an iterable of Participant objects."""
|
||||
for i, participant_data in enumerate(data, start=start):
|
||||
participant = cls.Participant(
|
||||
index=i,
|
||||
provider=ChecksumAddress(participant_data[0]),
|
||||
aggregated=participant_data[1],
|
||||
transcript=bytes(participant_data[2]),
|
||||
decryption_request_static_key=bytes(participant_data[3]),
|
||||
)
|
||||
yield participant
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_timeout(self) -> int:
|
||||
return self.contract.functions.timeout().call()
|
||||
|
@ -767,22 +780,6 @@ class CoordinatorAgent(EthereumContractAgent):
|
|||
ritual.public_key = self.Ritual.G1Point(result[10][0], result[10][1])
|
||||
return ritual
|
||||
|
||||
def get_ritual(
|
||||
self,
|
||||
ritual_id: int,
|
||||
transcripts: bool = False,
|
||||
) -> Ritual:
|
||||
"""assembles Coordinator data using multiple RPC requests."""
|
||||
ritual = self.__rituals(ritual_id) # 1 rpc call
|
||||
|
||||
participants = self.get_participants(
|
||||
ritual_id=ritual_id,
|
||||
num_participants=ritual.dkg_size,
|
||||
include_transcripts=transcripts,
|
||||
) # 1 rpc call
|
||||
ritual.participants = participants
|
||||
return ritual
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_ritual_status(self, ritual_id: int) -> int:
|
||||
result = self.contract.functions.getRitualState(ritual_id).call()
|
||||
|
@ -798,6 +795,52 @@ class CoordinatorAgent(EthereumContractAgent):
|
|||
result = self.contract.functions.isParticipant(ritual_id, provider).call()
|
||||
return result
|
||||
|
||||
def _get_participants(
|
||||
self,
|
||||
ritual: Ritual,
|
||||
ritual_id: int,
|
||||
transcripts: bool,
|
||||
page_size: int = 10, # Default pagination size
|
||||
) -> Iterable[Ritual.Participant]:
|
||||
"""Fetches all participants for a given ritual."""
|
||||
start, end = 0, ritual.dkg_size
|
||||
while start < end:
|
||||
current_page_size = min(page_size, end - start)
|
||||
batch = self.get_participants(
|
||||
ritual_id=ritual_id,
|
||||
start=start,
|
||||
end=start + current_page_size,
|
||||
transcripts=transcripts,
|
||||
)
|
||||
for participant in batch:
|
||||
yield participant
|
||||
start += current_page_size
|
||||
|
||||
def get_ritual(
|
||||
self,
|
||||
ritual_id: int,
|
||||
transcripts: bool = False,
|
||||
participants: bool = True,
|
||||
) -> Ritual:
|
||||
"""
|
||||
Exposes three views of Coordinator.Rituals:
|
||||
1. The ritual metadata only
|
||||
2. ritual + participant metadata
|
||||
3. ritual + participants + transcripts
|
||||
"""
|
||||
ritual = self.__rituals(ritual_id)
|
||||
if participants:
|
||||
ritual.participants = list(
|
||||
self._get_participants(
|
||||
ritual=ritual,
|
||||
ritual_id=ritual_id,
|
||||
transcripts=transcripts,
|
||||
)
|
||||
)
|
||||
elif transcripts:
|
||||
raise ValueError("Cannot fetch transcripts without participants")
|
||||
return ritual
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_participant(
|
||||
self, ritual_id: int, provider: ChecksumAddress, transcript: bool
|
||||
|
@ -805,55 +848,24 @@ class CoordinatorAgent(EthereumContractAgent):
|
|||
data, index = self.contract.functions.getParticipant(
|
||||
ritual_id, provider, transcript
|
||||
).call()
|
||||
participant = self.Ritual.Participant(
|
||||
index=index,
|
||||
provider=ChecksumAddress(data[0]),
|
||||
aggregated=data[1],
|
||||
transcript=bytes(data[2]),
|
||||
decryption_request_static_key=bytes(data[3]),
|
||||
)
|
||||
participant = next(iter(self.Ritual.make_participants([data], start=index)))
|
||||
return participant
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_participants(
|
||||
self,
|
||||
ritual_id: int,
|
||||
num_participants: int,
|
||||
pagination_size: Optional[int] = None,
|
||||
include_transcripts: Optional[bool] = False,
|
||||
start: Optional[int],
|
||||
end: Optional[int],
|
||||
transcripts: Optional[bool] = False,
|
||||
) -> List[Ritual.Participant]:
|
||||
ritual_participants = list()
|
||||
|
||||
if num_participants <= 0:
|
||||
raise ValueError("Number of participants must be >= 0")
|
||||
if pagination_size is None:
|
||||
pagination_size = 10 # TODO what's a good value here?
|
||||
self.log.debug(
|
||||
f"Defaulting to pagination size of {pagination_size} for participants"
|
||||
)
|
||||
elif pagination_size <= 0:
|
||||
raise ValueError("Pagination size must be > 0")
|
||||
|
||||
# TODO: another option is if include_transcripts is false, then don't bother paginating...?
|
||||
|
||||
if pagination_size > 0:
|
||||
start_index: int = 0
|
||||
while start_index < num_participants:
|
||||
result = self.contract.functions.getParticipants(
|
||||
ritual_id, start_index, pagination_size, include_transcripts
|
||||
).call()
|
||||
for i, participant_data in enumerate(result):
|
||||
participant = self.Ritual.Participant(
|
||||
index=start_index + i,
|
||||
provider=ChecksumAddress(participant_data[0]),
|
||||
aggregated=participant_data[1],
|
||||
transcript=bytes(participant_data[2]),
|
||||
decryption_request_static_key=bytes(participant_data[3]),
|
||||
)
|
||||
ritual_participants.append(participant)
|
||||
start_index += pagination_size
|
||||
|
||||
return ritual_participants
|
||||
if end < start:
|
||||
raise ValueError("End must be greater than or equal to start")
|
||||
data = self.contract.functions.getParticipants(
|
||||
ritual_id, start, end - start, transcripts
|
||||
).call()
|
||||
participants = self.Ritual.make_participants(data, start=start)
|
||||
return list(participants)
|
||||
|
||||
@contract_api(CONTRACT_CALL)
|
||||
def get_provider_public_key(
|
||||
|
|
|
@ -222,16 +222,16 @@ class MockCoordinatorAgent(MockContractAgent):
|
|||
return len(self.rituals)
|
||||
|
||||
def get_ritual(
|
||||
self, ritual_id: int, transcripts: bool = False
|
||||
self, ritual_id: int, transcripts: bool = False, participants: bool = True
|
||||
) -> CoordinatorAgent.Ritual:
|
||||
ritual = self.rituals[ritual_id]
|
||||
|
||||
# return a copy of the ritual object; the original value is used for state
|
||||
copied_ritual = deepcopy(ritual)
|
||||
if not participants:
|
||||
return copied_ritual
|
||||
if not transcripts:
|
||||
for participant in copied_ritual.participants:
|
||||
participant.transcript = bytes()
|
||||
|
||||
return copied_ritual
|
||||
|
||||
def is_participant(self, ritual_id: int, provider: ChecksumAddress) -> bool:
|
||||
|
|
Loading…
Reference in New Issue