Reduce public scope of rituals object on MockCoordinatorAgent. We don't really want anything changing state directly.

remotes/origin/v7.4.x
derekpierre 2024-01-29 12:19:56 -05:00 committed by Derek Pierre
parent 38c26d81aa
commit 6df1c037b7
3 changed files with 21 additions and 25 deletions

View File

@ -238,7 +238,7 @@ def test_ursula_ritualist(
def expired_ritual():
print("============ DKG DECRYPTION EXPIRED RITUAL =============")
ritual = mock_coordinator_agent.rituals[
ritual = mock_coordinator_agent._rituals[
ritual_id
] # if mocking state, use underlying object
time_in_past = mock_coordinator_agent.blockchain.get_blocktime() - 1

View File

@ -33,7 +33,7 @@ class MockCoordinatorAgent(MockContractAgent):
publicKey: Ferveo.G2Point
EVENTS = {}
rituals = []
_rituals = []
class Events(Enum):
START_RITUAL = 0
@ -77,7 +77,7 @@ class MockCoordinatorAgent(MockContractAgent):
#
def __find_participant_for_state_change(self, ritual_id, provider) -> Participant:
for p in self.rituals[ritual_id].participants:
for p in self._rituals[ritual_id].participants:
if p.provider == provider:
return p
@ -91,7 +91,7 @@ class MockCoordinatorAgent(MockContractAgent):
access_controller: ChecksumAddress,
transacting_power: TransactingPower,
) -> TxReceipt:
ritual_id = len(self.rituals)
ritual_id = len(self._rituals)
init_timestamp = int(time.time_ns())
end_timestamp = init_timestamp + duration
ritual = self.Ritual(
@ -108,7 +108,7 @@ class MockCoordinatorAgent(MockContractAgent):
dkg_size=len(providers),
threshold=self.get_threshold_for_ritual_size(len(providers)),
)
self.rituals.append(ritual)
self._rituals.append(ritual)
self.emit_event(
signal=self.Events.START_RITUAL,
ritual_id=ritual_id,
@ -123,7 +123,7 @@ class MockCoordinatorAgent(MockContractAgent):
transcript: Transcript,
transacting_power: TransactingPower,
) -> TxReceipt:
ritual = self.rituals[ritual_id]
ritual = self._rituals[ritual_id]
operator_address = transacting_power.account
# either mapping is populated or just assume provider same as operator for testing
provider = (
@ -152,7 +152,7 @@ class MockCoordinatorAgent(MockContractAgent):
participant_public_key: SessionStaticKey,
transacting_power: TransactingPower,
) -> TxReceipt:
ritual = self.rituals[ritual_id]
ritual = self._rituals[ritual_id]
operator_address = transacting_power.account
# either mapping is populated or just assume provider same as operator for testing
provider = (
@ -160,7 +160,7 @@ class MockCoordinatorAgent(MockContractAgent):
or transacting_power.account
)
participant = None
for p in self.rituals[ritual_id].participants:
for p in self._rituals[ritual_id].participants:
if p.provider == provider:
participant = p
break
@ -199,7 +199,7 @@ class MockCoordinatorAgent(MockContractAgent):
participant_keys.append(
self.ParticipantKey(
lastRitualId=len(self.rituals),
lastRitualId=self.number_of_rituals(),
publicKey=self.G2Point.from_public_key(public_key),
)
)
@ -217,16 +217,14 @@ class MockCoordinatorAgent(MockContractAgent):
return self.timeout
def number_of_rituals(self) -> int:
return len(self.rituals)
return len(self._rituals)
def get_ritual(
self, ritual_id: int, transcripts: bool = False, participants: bool = True
self, ritual_id: int, transcripts: bool = False
) -> Coordinator.Ritual:
ritual = self.rituals[ritual_id]
ritual = self._rituals[ritual_id]
# return a copy of the ritual object; the original value is used for state
copied_ritual = deepcopy(ritual)
if not participants:
return copied_ritual
if not transcripts:
for participant in copied_ritual.participants:
participant.transcript = bytes()
@ -242,7 +240,7 @@ class MockCoordinatorAgent(MockContractAgent):
def get_participant(
self, ritual_id: int, provider: ChecksumAddress, transcript: bool = False
) -> Participant:
for p in self.rituals[ritual_id].participants:
for p in self._rituals[ritual_id].participants:
if p.provider == provider:
copied_participant = deepcopy(p)
if not transcript:
@ -251,10 +249,10 @@ class MockCoordinatorAgent(MockContractAgent):
raise ValueError(f"Provider {provider} not found for ritual #{ritual_id}")
def get_providers(self, ritual_id: int) -> List[ChecksumAddress]:
return [p.provider for p in self.rituals[ritual_id].participants]
return [p.provider for p in self._rituals[ritual_id].participants]
def get_ritual_status(self, ritual_id: int) -> int:
ritual = self.rituals[ritual_id]
ritual = self._rituals[ritual_id]
timestamp = int(ritual.init_timestamp)
deadline = timestamp + self.timeout
if timestamp == 0:
@ -280,7 +278,7 @@ class MockCoordinatorAgent(MockContractAgent):
return result == self.RitualStatus.ACTIVE
def get_ritual_id_from_public_key(self, public_key: DkgPublicKey) -> int:
for i, ritual in enumerate(self.rituals):
for i, ritual in enumerate(self._rituals):
if bytes(ritual.public_key) == bytes(public_key):
return i

View File

@ -29,7 +29,7 @@ def coordinator():
def test_mock_coordinator_creation(coordinator):
assert len(coordinator.rituals) == 0
assert coordinator.number_of_rituals() == 0
def test_mock_coordinator_initiation(
@ -39,7 +39,7 @@ def test_mock_coordinator_initiation(
random_address,
get_random_checksum_address,
):
assert len(coordinator.rituals) == 0
assert coordinator.number_of_rituals() == 0
mock_transacting_power = mocker.Mock()
mock_transacting_power.account = random_address
coordinator.initiate_ritual(
@ -49,11 +49,9 @@ def test_mock_coordinator_initiation(
access_controller=get_random_checksum_address(),
transacting_power=mock_transacting_power,
)
assert len(coordinator.rituals) == 1
assert coordinator.number_of_rituals() == 1
ritual = coordinator.rituals[0]
ritual = coordinator._rituals[0]
assert len(ritual.participants) == DKG_SIZE
for p in ritual.participants:
assert p.transcript == bytes()
@ -71,7 +69,7 @@ def test_mock_coordinator_initiation(
def test_mock_coordinator_round_1(
nodes_transacting_powers, coordinator, random_transcript
):
ritual = coordinator.rituals[0]
ritual = coordinator._rituals[0]
assert (
coordinator.get_ritual_status(0)
== Coordinator.RitualStatus.DKG_AWAITING_TRANSCRIPTS
@ -108,7 +106,7 @@ def test_mock_coordinator_round_2(
dkg_public_key,
random_transcript,
):
ritual = coordinator.rituals[0]
ritual = coordinator._rituals[0]
assert (
coordinator.get_ritual_status(0)
== Coordinator.RitualStatus.DKG_AWAITING_AGGREGATIONS